Commit 6f514024 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Refactor the XLNet pretrain data generation script(s).

PiperOrigin-RevId: 343108621
parent 713142d1
This diff is collapsed.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.data.create_xlnet_pretraining_data."""
import os
import tempfile
from typing import List
from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.data import create_xlnet_pretraining_data as cpd
_VOCAB_WORDS = ["vocab_1", "vocab_2"]
# pylint: disable=invalid-name
def _create_files(
temp_dir: str, file_contents: List[List[str]]) -> List[str]:
"""Writes arbitrary documents into files."""
root_dir = tempfile.mkdtemp(dir=temp_dir)
files = []
for i, file_content in enumerate(file_contents):
destination = os.path.join(root_dir, "%d.txt" % i)
with open(destination, "wb") as f:
for line in file_content:
f.write(line.encode("utf-8"))
files.append(destination)
return files
def _get_mock_tokenizer():
"""Creates a mock tokenizer."""
class MockSpieceModel:
"""Mock Spiece model for testing."""
def __init__(self):
self._special_piece_to_id = {
"<unk>": 0,
}
for piece in set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~')):
self._special_piece_to_id[piece] = 1
def EncodeAsPieces(self, inputs: str) -> List[str]:
return inputs
def SampleEncodeAsPieces(self,
inputs: str,
nbest_size: int,
theta: float) -> List[str]:
del nbest_size, theta
return inputs
def PieceToId(self, piece: str) -> int:
return ord(piece[0])
def IdToPiece(self, id_: int) -> str:
return chr(id_) * 3
class Tokenizer:
"""Mock Tokenizer for testing."""
def __init__(self):
self.sp_model = MockSpieceModel()
def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
return [self.sp_model.IdToPiece(id_) for id_ in ids]
return Tokenizer()
class PreprocessDataTest(tf.test.TestCase):
def test_remove_extraneous_space(self):
line = " abc "
output = cpd._preprocess_line(line)
self.assertEqual(output, "abc")
def test_symbol_replacements(self):
self.assertEqual(cpd._preprocess_line("``abc``"), "\"abc\"")
self.assertEqual(cpd._preprocess_line("''abc''"), "\"abc\"")
def test_accent_replacements(self):
self.assertEqual(cpd._preprocess_line("åbc"), "abc")
def test_lower_case(self):
self.assertEqual(cpd._preprocess_line("ABC", do_lower_case=True), "abc")
def test_end_to_end(self):
self.assertEqual(
cpd._preprocess_line("HelLo ``wórLd``", do_lower_case=True),
"hello \"world\"")
class PreprocessAndTokenizeFilesTest(tf.test.TestCase):
def test_basic_end_to_end(self):
documents = [
[
"This is sentence 1.\n",
"This is sentence 2.\n",
"Sentence 3 is what this is.\n",
],
[
"This is the second document.\n",
"This is the second line of the second document.\n"
],
]
input_files = _create_files(temp_dir=self.get_temp_dir(),
file_contents=documents)
all_data = cpd.preprocess_and_tokenize_input_files(
input_files=input_files,
tokenizer=_get_mock_tokenizer(),
log_example_freq=1)
self.assertEqual(len(all_data), len(documents))
for token_ids, sentence_ids in all_data:
self.assertEqual(len(token_ids), len(sentence_ids))
def test_basic_correctness(self):
documents = [["a\n", "b\n", "c\n"]]
input_files = _create_files(temp_dir=self.get_temp_dir(),
file_contents=documents)
all_data = cpd.preprocess_and_tokenize_input_files(
input_files=input_files,
tokenizer=_get_mock_tokenizer(),
log_example_freq=1)
token_ids, sentence_ids = all_data[0]
self.assertAllClose(token_ids, [97, 98, 99])
self.assertAllClose(sentence_ids, [True, False, True])
def test_correctness_with_spaces_and_accents(self):
documents = [[
" å \n",
"b \n",
" c \n",
]]
input_files = _create_files(temp_dir=self.get_temp_dir(),
file_contents=documents)
all_data = cpd.preprocess_and_tokenize_input_files(
input_files=input_files,
tokenizer=_get_mock_tokenizer(),
log_example_freq=1)
token_ids, sentence_ids = all_data[0]
self.assertAllClose(token_ids, [97, 98, 99])
self.assertAllClose(sentence_ids, [True, False, True])
class BatchReshapeTests(tf.test.TestCase):
def test_basic_functionality(self):
per_host_batch_size = 3
mock_shape = (20,)
# Should truncate and reshape.
expected_result_shape = (3, 6)
tokens = np.zeros(mock_shape)
sentence_ids = np.zeros(mock_shape)
reshaped_data = cpd._reshape_to_batch_dimensions(
tokens=tokens,
sentence_ids=sentence_ids,
per_host_batch_size=per_host_batch_size)
for values in reshaped_data:
self.assertEqual(len(values.flatten()) % per_host_batch_size, 0)
self.assertAllClose(values.shape, expected_result_shape)
class CreateSegmentsTest(tf.test.TestCase):
def test_basic_functionality(self):
data_length = 10
tokens = np.arange(data_length)
sentence_ids = np.concatenate([np.zeros(data_length // 2),
np.ones(data_length // 2)])
begin_index = 0
total_length = 8
a_data, b_data, label = cpd._create_a_and_b_segments(
tokens=tokens,
sentence_ids=sentence_ids,
begin_index=begin_index,
total_length=total_length,
no_cut_probability=0.)
self.assertAllClose(a_data, [0, 1, 2, 3])
self.assertAllClose(b_data, [5, 6, 7, 8])
self.assertEqual(label, 1)
def test_no_cut(self):
data_length = 10
tokens = np.arange(data_length)
sentence_ids = np.zeros(data_length)
begin_index = 0
total_length = 8
a_data, b_data, label = cpd._create_a_and_b_segments(
tokens=tokens,
sentence_ids=sentence_ids,
begin_index=begin_index,
total_length=total_length,
no_cut_probability=0.)
self.assertGreater(len(a_data), 0)
self.assertGreater(len(b_data), 0)
self.assertEqual(label, 0)
def test_no_cut_with_probability(self):
data_length = 10
tokens = np.arange(data_length)
sentence_ids = np.concatenate([np.zeros(data_length // 2),
np.ones(data_length // 2)])
begin_index = 0
total_length = 8
a_data, b_data, label = cpd._create_a_and_b_segments(
tokens=tokens,
sentence_ids=sentence_ids,
begin_index=begin_index,
total_length=total_length,
no_cut_probability=1.)
self.assertGreater(len(a_data), 0)
self.assertGreater(len(b_data), 0)
self.assertEqual(label, 0)
class CreateInstancesTest(tf.test.TestCase):
"""Tests conversions of Token/Sentence IDs to training instances."""
def test_basic(self):
data_length = 12
tokens = np.arange(data_length)
sentence_ids = np.zeros(data_length)
seq_length = 8
instances = cpd._convert_tokens_to_instances(
tokens=tokens,
sentence_ids=sentence_ids,
per_host_batch_size=2,
seq_length=seq_length,
reuse_length=4,
tokenizer=_get_mock_tokenizer(),
bi_data=False,
num_cores_per_host=1,
logging_frequency=1)
for instance in instances:
self.assertEqual(len(instance.data), seq_length)
self.assertEqual(len(instance.segment_ids), seq_length)
self.assertIsInstance(instance.label, int)
self.assertIsInstance(instance.boundary_indices, list)
class TFRecordPathTests(tf.test.TestCase):
def test_basic(self):
base_kwargs = dict(
per_host_batch_size=1,
num_cores_per_host=1,
seq_length=2,
reuse_length=1)
config1 = dict(
prefix="test",
suffix="",
bi_data=True,
use_eod_token=False,
do_lower_case=True)
config1.update(base_kwargs)
expectation1 = "test_seqlen-2_reuse-1_bs-1_cores-1_uncased_bi.tfrecord"
self.assertEqual(cpd.get_tfrecord_name(**config1), expectation1)
config2 = dict(
prefix="",
suffix="test",
bi_data=False,
use_eod_token=False,
do_lower_case=False)
config2.update(base_kwargs)
expectation2 = "seqlen-2_reuse-1_bs-1_cores-1_cased_uni_test.tfrecord"
self.assertEqual(cpd.get_tfrecord_name(**config2), expectation2)
config3 = dict(
prefix="",
suffix="",
use_eod_token=True,
bi_data=False,
do_lower_case=True)
config3.update(base_kwargs)
expectation3 = "seqlen-2_reuse-1_bs-1_cores-1_uncased_eod_uni.tfrecord"
self.assertEqual(cpd.get_tfrecord_name(**config3), expectation3)
class TestCreateTFRecords(parameterized.TestCase, tf.test.TestCase):
@parameterized.named_parameters(
("bi_data_only", True, False, False),
("eod_token_only", False, True, True),
("lower_case_only", False, False, True),
("all_enabled", True, True, True),
)
def test_end_to_end(self,
bi_data: bool,
use_eod_token: bool,
do_lower_case: bool):
tokenizer = _get_mock_tokenizer()
num_documents = 5
sentences_per_document = 10
document_length = 50
documents = [
["a " * document_length for _ in range(sentences_per_document)]
for _ in range(num_documents)]
save_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
files = _create_files(temp_dir=self.get_temp_dir(), file_contents=documents)
cpd.create_tfrecords(
tokenizer=tokenizer,
input_file_or_files=",".join(files),
use_eod_token=use_eod_token,
do_lower_case=do_lower_case,
per_host_batch_size=8,
seq_length=8,
reuse_length=4,
bi_data=bi_data,
num_cores_per_host=2,
save_dir=save_dir)
self.assertTrue(any(filter(lambda x: x.endswith(".json"),
os.listdir(save_dir))))
self.assertTrue(any(filter(lambda x: x.endswith(".tfrecord"),
os.listdir(save_dir))))
if __name__ == "__main__":
np.random.seed(0)
logging.set_verbosity(logging.INFO)
tf.test.main()
......@@ -143,8 +143,7 @@ class XLNetPretrainDataConfig(cfg.DataConfig):
reuse_length: The number of tokens in a previous segment to reuse. This
should be the same value used during pretrain data creation.
sample_strategy: The strategy used to sample factorization permutations.
Possible values: 'fixed', 'single_token', 'whole_word', 'token_span',
'word_span'.
Possible values: 'single_token', 'whole_word', 'token_span', 'word_span'.
min_num_tokens: The minimum number of tokens to sample in a span.
This is used when `sample_strategy` is 'token_span'.
max_num_tokens: The maximum number of tokens to sample in a span.
......@@ -208,12 +207,8 @@ class XLNetPretrainDataLoader(data_loader.DataLoader):
tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_type_ids':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
'target':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
'boundary_indices':
tf.io.VarLenFeature(tf.int64),
'input_mask':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
}
example = tf.io.parse_single_example(record, name_to_features)
......@@ -234,20 +229,12 @@ class XLNetPretrainDataLoader(data_loader.DataLoader):
inputs = record['input_word_ids']
x['input_type_ids'] = record['input_type_ids']
if self._sample_strategy == 'fixed':
input_mask = record['input_mask']
else:
input_mask = None
if self._sample_strategy in ['whole_word', 'word_span']:
boundary = tf.sparse.to_dense(record['boundary_indices'])
else:
boundary = None
input_mask = self._online_sample_mask(
inputs=inputs,
input_mask=input_mask,
boundary=boundary)
input_mask = self._online_sample_mask(inputs=inputs, boundary=boundary)
if self._reuse_length > 0:
if self._permutation_size > self._reuse_length:
......@@ -503,14 +490,10 @@ class XLNetPretrainDataLoader(data_loader.DataLoader):
def _online_sample_mask(self,
inputs: tf.Tensor,
input_mask: tf.Tensor,
boundary: tf.Tensor) -> tf.Tensor:
"""Samples target positions for predictions.
Descriptions of each strategy:
- 'fixed': Returns the input mask that was computed during pretrain data
creation. The value for `max_predictions_per_seq` must match the value
used during dataset creation.
- 'single_token': Samples individual tokens as prediction targets.
- 'token_span': Samples spans of tokens as prediction targets.
- 'whole_word': Samples individual words as prediction targets.
......@@ -518,9 +501,6 @@ class XLNetPretrainDataLoader(data_loader.DataLoader):
Args:
inputs: The input tokens.
input_mask: The `bool` Tensor of the same shape as `inputs`. This is the
input mask calculated when creating pretraining the pretraining dataset.
If `sample_strategy` is not 'fixed', this is not used.
boundary: The `int` Tensor of indices indicating whole word boundaries.
This is used in 'whole_word' and 'word_span'
......@@ -528,26 +508,17 @@ class XLNetPretrainDataLoader(data_loader.DataLoader):
The sampled `bool` input mask.
Raises:
`ValueError`: if `max_predictions_per_seq` is not set
and the sample strategy is not 'fixed', or if boundary is not provided
for 'whole_word' and 'word_span' sample strategies.
`ValueError`: if `max_predictions_per_seq` is not set or if boundary is
not provided for 'whole_word' and 'word_span' sample strategies.
"""
if (self._sample_strategy != 'fixed' and
self._max_predictions_per_seq is None):
raise ValueError(
'`max_predictions_per_seq` must be set if using '
'sample strategy {}.'.format(self._sample_strategy))
if self._max_predictions_per_seq is None:
raise ValueError('`max_predictions_per_seq` must be set.')
if boundary is None and 'word' in self._sample_strategy:
raise ValueError('`boundary` must be provided for {} strategy'.format(
self._sample_strategy))
if self._sample_strategy == 'fixed':
# Uses the computed input masks from preprocessing.
# Note: This should have `max_predictions_per_seq` number of tokens set
# to 1.
return tf.cast(input_mask, tf.bool)
elif self._sample_strategy == 'single_token':
if self._sample_strategy == 'single_token':
return self._single_token_mask(inputs)
elif self._sample_strategy == 'token_span':
return self._token_span_mask(inputs)
......
......@@ -174,7 +174,7 @@ class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
class XLNetPretrainDataTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(itertools.product(
("fixed", "single_token", "whole_word", "token_span"),
("single_token", "whole_word", "token_span"),
(0, 64),
(20, None),
))
......@@ -200,9 +200,8 @@ class XLNetPretrainDataTest(parameterized.TestCase, tf.test.TestCase):
permutation_size=seq_length // 2,
leak_ratio=0.1)
if (max_predictions_per_seq is None and sample_strategy != "fixed"):
with self.assertRaisesWithRegexpMatch(
ValueError, "`max_predictions_per_seq` must be set"):
if max_predictions_per_seq is None:
with self.assertRaises(ValueError):
dataset = pretrain_dataloader.XLNetPretrainDataLoader(
data_config).load()
features = next(iter(dataset))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment