Unverified commit f16a7b5b, authored by vedanshu, committed by GitHub

Merge pull request #1 from tensorflow/master

new pull
parents 8e9296ff 8f58f396
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.data.create_xlnet_pretraining_data."""
import os
import tempfile
from typing import List
from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.data import create_xlnet_pretraining_data as cpd
_VOCAB_WORDS = ["vocab_1", "vocab_2"]
# pylint: disable=invalid-name
def _create_files(
temp_dir: str, file_contents: List[List[str]]) -> List[str]:
"""Writes arbitrary documents into files."""
root_dir = tempfile.mkdtemp(dir=temp_dir)
files = []
for i, file_content in enumerate(file_contents):
destination = os.path.join(root_dir, "%d.txt" % i)
with open(destination, "wb") as f:
for line in file_content:
f.write(line.encode("utf-8"))
files.append(destination)
return files
def _get_mock_tokenizer():
"""Creates a mock tokenizer."""
class MockSpieceModel:
"""Mock Spiece model for testing."""
def __init__(self):
self._special_piece_to_id = {
"<unk>": 0,
}
for piece in set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~')):
self._special_piece_to_id[piece] = 1
def EncodeAsPieces(self, inputs: str) -> List[str]:
return inputs
def SampleEncodeAsPieces(self,
inputs: str,
nbest_size: int,
theta: float) -> List[str]:
del nbest_size, theta
return inputs
def PieceToId(self, piece: str) -> int:
return ord(piece[0])
def IdToPiece(self, id_: int) -> str:
return chr(id_) * 3
class Tokenizer:
"""Mock Tokenizer for testing."""
def __init__(self):
self.sp_model = MockSpieceModel()
def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
return [self.sp_model.IdToPiece(id_) for id_ in ids]
return Tokenizer()
class PreprocessDataTest(tf.test.TestCase):
def test_remove_extraneous_space(self):
line = " abc "
output = cpd._preprocess_line(line)
self.assertEqual(output, "abc")
def test_symbol_replacements(self):
self.assertEqual(cpd._preprocess_line("``abc``"), "\"abc\"")
self.assertEqual(cpd._preprocess_line("''abc''"), "\"abc\"")
def test_accent_replacements(self):
self.assertEqual(cpd._preprocess_line("åbc"), "abc")
def test_lower_case(self):
self.assertEqual(cpd._preprocess_line("ABC", do_lower_case=True), "abc")
def test_end_to_end(self):
self.assertEqual(
cpd._preprocess_line("HelLo ``wórLd``", do_lower_case=True),
"hello \"world\"")
class PreprocessAndTokenizeFilesTest(tf.test.TestCase):
def test_basic_end_to_end(self):
documents = [
[
"This is sentence 1.\n",
"This is sentence 2.\n",
"Sentence 3 is what this is.\n",
],
[
"This is the second document.\n",
"This is the second line of the second document.\n"
],
]
input_files = _create_files(temp_dir=self.get_temp_dir(),
file_contents=documents)
all_data = cpd.preprocess_and_tokenize_input_files(
input_files=input_files,
tokenizer=_get_mock_tokenizer(),
log_example_freq=1)
self.assertEqual(len(all_data), len(documents))
for token_ids, sentence_ids in all_data:
self.assertEqual(len(token_ids), len(sentence_ids))
def test_basic_correctness(self):
documents = [["a\n", "b\n", "c\n"]]
input_files = _create_files(temp_dir=self.get_temp_dir(),
file_contents=documents)
all_data = cpd.preprocess_and_tokenize_input_files(
input_files=input_files,
tokenizer=_get_mock_tokenizer(),
log_example_freq=1)
token_ids, sentence_ids = all_data[0]
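    # The mock tokenizer maps each piece to ord(first_char), so the lines
    # "a", "b", "c" tokenize to ids [97, 98, 99]; sentence ids alternate
    # between True and False for consecutive lines within a document.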
self.assertAllClose(token_ids, [97, 98, 99])
self.assertAllClose(sentence_ids, [True, False, True])
def test_correctness_with_spaces_and_accents(self):
documents = [[
" å \n",
"b \n",
" c \n",
]]
input_files = _create_files(temp_dir=self.get_temp_dir(),
file_contents=documents)
all_data = cpd.preprocess_and_tokenize_input_files(
input_files=input_files,
tokenizer=_get_mock_tokenizer(),
log_example_freq=1)
token_ids, sentence_ids = all_data[0]
self.assertAllClose(token_ids, [97, 98, 99])
self.assertAllClose(sentence_ids, [True, False, True])
class BatchReshapeTests(tf.test.TestCase):
def test_basic_functionality(self):
per_host_batch_size = 3
mock_shape = (20,)
# Should truncate and reshape.
expected_result_shape = (3, 6)
tokens = np.zeros(mock_shape)
sentence_ids = np.zeros(mock_shape)
reshaped_data = cpd._reshape_to_batch_dimensions(
tokens=tokens,
sentence_ids=sentence_ids,
per_host_batch_size=per_host_batch_size)
for values in reshaped_data:
self.assertEqual(len(values.flatten()) % per_host_batch_size, 0)
self.assertAllClose(values.shape, expected_result_shape)
class CreateSegmentsTest(tf.test.TestCase):
def test_basic_functionality(self):
data_length = 10
tokens = np.arange(data_length)
sentence_ids = np.concatenate([np.zeros(data_length // 2),
np.ones(data_length // 2)])
begin_index = 0
total_length = 8
a_data, b_data, label = cpd._create_a_and_b_segments(
tokens=tokens,
sentence_ids=sentence_ids,
begin_index=begin_index,
total_length=total_length,
no_cut_probability=0.)
self.assertAllClose(a_data, [0, 1, 2, 3])
self.assertAllClose(b_data, [5, 6, 7, 8])
self.assertEqual(label, 1)
def test_no_cut(self):
data_length = 10
tokens = np.arange(data_length)
sentence_ids = np.zeros(data_length)
begin_index = 0
total_length = 8
a_data, b_data, label = cpd._create_a_and_b_segments(
tokens=tokens,
sentence_ids=sentence_ids,
begin_index=begin_index,
total_length=total_length,
no_cut_probability=0.)
self.assertGreater(len(a_data), 0)
self.assertGreater(len(b_data), 0)
self.assertEqual(label, 0)
def test_no_cut_with_probability(self):
data_length = 10
tokens = np.arange(data_length)
sentence_ids = np.concatenate([np.zeros(data_length // 2),
np.ones(data_length // 2)])
begin_index = 0
total_length = 8
a_data, b_data, label = cpd._create_a_and_b_segments(
tokens=tokens,
sentence_ids=sentence_ids,
begin_index=begin_index,
total_length=total_length,
no_cut_probability=1.)
self.assertGreater(len(a_data), 0)
self.assertGreater(len(b_data), 0)
self.assertEqual(label, 0)
class CreateInstancesTest(tf.test.TestCase):
"""Tests conversions of Token/Sentence IDs to training instances."""
def test_basic(self):
data_length = 12
tokens = np.arange(data_length)
sentence_ids = np.zeros(data_length)
seq_length = 8
instances = cpd._convert_tokens_to_instances(
tokens=tokens,
sentence_ids=sentence_ids,
per_host_batch_size=2,
seq_length=seq_length,
reuse_length=4,
tokenizer=_get_mock_tokenizer(),
bi_data=False,
num_cores_per_host=1,
logging_frequency=1)
for instance in instances:
self.assertEqual(len(instance.data), seq_length)
self.assertEqual(len(instance.segment_ids), seq_length)
self.assertIsInstance(instance.label, int)
self.assertIsInstance(instance.boundary_indices, list)
class TFRecordPathTests(tf.test.TestCase):
def test_basic(self):
base_kwargs = dict(
per_host_batch_size=1,
num_cores_per_host=1,
seq_length=2,
reuse_length=1)
config1 = dict(
prefix="test",
suffix="",
bi_data=True,
use_eod_token=False,
do_lower_case=True)
config1.update(base_kwargs)
expectation1 = "test_seqlen-2_reuse-1_bs-1_cores-1_uncased_bi.tfrecord"
self.assertEqual(cpd.get_tfrecord_name(**config1), expectation1)
config2 = dict(
prefix="",
suffix="test",
bi_data=False,
use_eod_token=False,
do_lower_case=False)
config2.update(base_kwargs)
expectation2 = "seqlen-2_reuse-1_bs-1_cores-1_cased_uni_test.tfrecord"
self.assertEqual(cpd.get_tfrecord_name(**config2), expectation2)
config3 = dict(
prefix="",
suffix="",
use_eod_token=True,
bi_data=False,
do_lower_case=True)
config3.update(base_kwargs)
expectation3 = "seqlen-2_reuse-1_bs-1_cores-1_uncased_eod_uni.tfrecord"
self.assertEqual(cpd.get_tfrecord_name(**config3), expectation3)
class TestCreateTFRecords(parameterized.TestCase, tf.test.TestCase):
@parameterized.named_parameters(
("bi_data_only", True, False, False),
("eod_token_only", False, True, True),
("lower_case_only", False, False, True),
("all_enabled", True, True, True),
)
def test_end_to_end(self,
bi_data: bool,
use_eod_token: bool,
do_lower_case: bool):
tokenizer = _get_mock_tokenizer()
num_documents = 5
sentences_per_document = 10
document_length = 50
documents = [
["a " * document_length for _ in range(sentences_per_document)]
for _ in range(num_documents)]
save_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
files = _create_files(temp_dir=self.get_temp_dir(), file_contents=documents)
cpd.create_tfrecords(
tokenizer=tokenizer,
input_file_or_files=",".join(files),
use_eod_token=use_eod_token,
do_lower_case=do_lower_case,
per_host_batch_size=8,
seq_length=8,
reuse_length=4,
bi_data=bi_data,
num_cores_per_host=2,
save_dir=save_dir)
self.assertTrue(any(filter(lambda x: x.endswith(".json"),
os.listdir(save_dir))))
self.assertTrue(any(filter(lambda x: x.endswith(".tfrecord"),
os.listdir(save_dir))))
if __name__ == "__main__":
np.random.seed(0)
logging.set_verbosity(logging.INFO)
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An abstraction that NLP models define input pipelines."""
import abc
from typing import Optional
import tensorflow as tf
class DataLoader(metaclass=abc.ABCMeta):
"""An abstract class defining the APIs for tf.data input pipeline."""
@abc.abstractmethod
def load(
self,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Implements DataLoader load method.
    Builds the entire input pipeline inside the load method. Users can define
    state inside the DataLoader class and return a tf.data.Dataset
    object.
Args:
input_context: This is a context class that is passed to the user's input
function and contains information about the compute replicas and input
pipelines. This object is used for multi-host inputs and passed by the
distribution strategy.
Returns:
      A per-host tf.data.Dataset. Note that we usually create the distributed
      dataset through the load method, so we should not directly return a
      distributed dataset here.
"""
pass
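For illustration, a minimal concrete subclass might look like the sketch below.
This is not part of the commit; the class name, file pattern, and feature spec
are hypothetical:

import tensorflow as tf
from typing import Optional
from official.nlp.data import data_loader

class ToyDataLoader(data_loader.DataLoader):
  """Hypothetical loader parsing a single int64 feature from TFRecords."""

  def __init__(self, file_pattern: str, batch_size: int):
    self._file_pattern = file_pattern
    self._batch_size = batch_size

  def load(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(self._file_pattern))
    if input_context:
      # Shard so each input pipeline (host) reads a distinct slice of files.
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    dataset = dataset.map(
        lambda record: tf.io.parse_single_example(
            record, {'x': tf.io.FixedLenFeature([1], tf.int64)}))
    return dataset.batch(self._batch_size)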
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ==============================================================================
 """A global factory to access NLP registered data loaders."""
-from official.utils import registry
+from official.core import registry
 _REGISTERED_DATA_LOADER_CLS = {}
...
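The remainder of data_loader_factory.py is collapsed above ("..."). Based on
the registry import and the test below, a behaviorally equivalent sketch of its
two public functions would look like this (not the verbatim file, which
presumably delegates to official.core.registry):

def register_data_loader_cls(data_config_cls):
  """Decorator that maps a DataConfig subclass to a data loader class."""
  def decorator(data_loader_cls):
    _REGISTERED_DATA_LOADER_CLS[data_config_cls] = data_loader_cls
    return data_loader_cls
  return decorator

def get_data_loader(params):
  """Instantiates the data loader class registered for type(params)."""
  return _REGISTERED_DATA_LOADER_CLS[type(params)](params)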
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.data.data_loader_factory."""
import dataclasses
import tensorflow as tf
from official.core import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class MyDataConfig(cfg.DataConfig):
is_training: bool = True
@data_loader_factory.register_data_loader_cls(MyDataConfig)
class MyDataLoader:
def __init__(self, params):
self.params = params
class DataLoaderFactoryTest(tf.test.TestCase):
def test_register_and_load(self):
train_config = MyDataConfig()
train_loader = data_loader_factory.get_data_loader(train_config)
self.assertTrue(train_loader.params.is_training)
if __name__ == "__main__":
tf.test.main()
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,15 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ==============================================================================
 """Loads dataset for the BERT pretraining task."""
 from typing import Mapping, Optional
+from absl import logging
 import dataclasses
+import numpy as np
 import tensorflow as tf
+from official.core import config_definitions as cfg
 from official.core import input_reader
-from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader
 from official.nlp.data import data_loader_factory
@@ -34,10 +36,16 @@ class BertPretrainDataConfig(cfg.DataConfig):
   max_predictions_per_seq: int = 76
   use_next_sentence_label: bool = True
   use_position_id: bool = False
+  # Historically, BERT implementations take `input_ids` and `segment_ids` as
+  # feature names. Inside the TF Model Garden implementation, the Keras model
+  # inputs are set as `input_word_ids` and `input_type_ids`. When
+  # v2_feature_names is True, the data loader assumes the tf.Examples use
+  # `input_word_ids` and `input_type_ids` as keys.
+  use_v2_feature_names: bool = False
 @data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
-class BertPretrainDataLoader:
+class BertPretrainDataLoader(data_loader.DataLoader):
   """A class to load dataset for bert pretraining task."""
   def __init__(self, params):
@@ -52,15 +60,10 @@ class BertPretrainDataLoader:
     self._use_next_sentence_label = params.use_next_sentence_label
     self._use_position_id = params.use_position_id
-  def _decode(self, record: tf.Tensor):
-    """Decodes a serialized tf.Example."""
+  def _name_to_features(self):
     name_to_features = {
-        'input_ids':
-            tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'input_mask':
             tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'segment_ids':
-            tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'masked_lm_positions':
             tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
         'masked_lm_ids':
@@ -68,13 +71,27 @@ class BertPretrainDataLoader:
         'masked_lm_weights':
            tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.float32),
     }
+    if self._params.use_v2_feature_names:
+      name_to_features.update({
+          'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+          'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+      })
+    else:
+      name_to_features.update({
+          'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+          'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+      })
     if self._use_next_sentence_label:
       name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
                                                                        tf.int64)
     if self._use_position_id:
       name_to_features['position_ids'] = tf.io.FixedLenFeature(
           [self._seq_length], tf.int64)
+    return name_to_features
+  def _decode(self, record: tf.Tensor):
+    """Decodes a serialized tf.Example."""
+    name_to_features = self._name_to_features()
     example = tf.io.parse_single_example(record, name_to_features)
     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
@@ -90,13 +107,17 @@ class BertPretrainDataLoader:
   def _parse(self, record: Mapping[str, tf.Tensor]):
     """Parses raw tensors into a dict of tensors to be consumed by the model."""
     x = {
-        'input_word_ids': record['input_ids'],
         'input_mask': record['input_mask'],
-        'input_type_ids': record['segment_ids'],
         'masked_lm_positions': record['masked_lm_positions'],
         'masked_lm_ids': record['masked_lm_ids'],
         'masked_lm_weights': record['masked_lm_weights'],
     }
+    if self._params.use_v2_feature_names:
+      x['input_word_ids'] = record['input_word_ids']
+      x['input_type_ids'] = record['input_type_ids']
+    else:
+      x['input_word_ids'] = record['input_ids']
+      x['input_type_ids'] = record['segment_ids']
     if self._use_next_sentence_label:
       x['next_sentence_labels'] = record['next_sentence_labels']
     if self._use_position_id:
@@ -109,3 +130,475 @@ class BertPretrainDataLoader:
     reader = input_reader.InputReader(
         params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
     return reader.read(input_context)
@dataclasses.dataclass
class XLNetPretrainDataConfig(cfg.DataConfig):
"""Data config for XLNet pretraining task.
Attributes:
input_path: See base class.
    global_batch_size: See base class.
is_training: See base class.
seq_length: The length of each sequence.
max_predictions_per_seq: The number of predictions per sequence.
reuse_length: The number of tokens in a previous segment to reuse. This
should be the same value used during pretrain data creation.
sample_strategy: The strategy used to sample factorization permutations.
Possible values: 'single_token', 'whole_word', 'token_span', 'word_span'.
min_num_tokens: The minimum number of tokens to sample in a span.
This is used when `sample_strategy` is 'token_span'.
max_num_tokens: The maximum number of tokens to sample in a span.
This is used when `sample_strategy` is 'token_span'.
min_num_words: The minimum number of words to sample in a span.
This is used when `sample_strategy` is 'word_span'.
max_num_words: The maximum number of words to sample in a span.
This is used when `sample_strategy` is 'word_span'.
permutation_size: The length of the longest permutation. This can be set
to `reuse_length`. This should NOT be greater than `reuse_length`,
otherwise this may introduce data leaks.
leak_ratio: The percentage of masked tokens that are leaked.
segment_sep_id: The ID of the SEP token used when preprocessing
the dataset.
segment_cls_id: The ID of the CLS token used when preprocessing
the dataset.
"""
input_path: str = ''
global_batch_size: int = 512
is_training: bool = True
seq_length: int = 512
max_predictions_per_seq: int = 76
reuse_length: int = 256
sample_strategy: str = 'word_span'
min_num_tokens: int = 1
max_num_tokens: int = 5
min_num_words: int = 1
max_num_words: int = 5
permutation_size: int = 256
leak_ratio: float = 0.1
segment_sep_id: int = 4
segment_cls_id: int = 3
@data_loader_factory.register_data_loader_cls(XLNetPretrainDataConfig)
class XLNetPretrainDataLoader(data_loader.DataLoader):
"""A class to load dataset for xlnet pretraining task."""
def __init__(self, params: XLNetPretrainDataConfig):
"""Inits `XLNetPretrainDataLoader` class.
Args:
params: A `XLNetPretrainDataConfig` object.
"""
self._params = params
self._seq_length = params.seq_length
self._max_predictions_per_seq = params.max_predictions_per_seq
self._reuse_length = params.reuse_length
self._num_replicas_in_sync = None
self._permutation_size = params.permutation_size
self._sep_id = params.segment_sep_id
self._cls_id = params.segment_cls_id
self._sample_strategy = params.sample_strategy
self._leak_ratio = params.leak_ratio
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'input_word_ids':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_type_ids':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
'boundary_indices':
tf.io.VarLenFeature(tf.int64),
}
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _parse(self, record: Mapping[str, tf.Tensor]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x = {}
inputs = record['input_word_ids']
x['input_type_ids'] = record['input_type_ids']
if self._sample_strategy in ['whole_word', 'word_span']:
boundary = tf.sparse.to_dense(record['boundary_indices'])
else:
boundary = None
input_mask = self._online_sample_mask(inputs=inputs, boundary=boundary)
if self._reuse_length > 0:
if self._permutation_size > self._reuse_length:
logging.warning(
            '`permutation_size` is greater than `reuse_length` (%d > %d). '
            'This may introduce data leakage.',
self._permutation_size, self._reuse_length)
# Enable the memory mechanism.
# Permute the reuse and non-reuse segments separately.
non_reuse_len = self._seq_length - self._reuse_length
if not (self._reuse_length % self._permutation_size == 0
and non_reuse_len % self._permutation_size == 0):
raise ValueError('`reuse_length` and `seq_length` should both be '
'a multiple of `permutation_size`.')
# Creates permutation mask and target mask for the first reuse_len tokens.
# The tokens in this part are reused from the last sequence.
perm_mask_0, target_mask_0, tokens_0, masked_0 = self._get_factorization(
inputs=inputs[:self._reuse_length],
input_mask=input_mask[:self._reuse_length])
      # Creates permutation mask and target mask for the rest of the tokens in
      # the current example, which are the concatenation of two new segments.
perm_mask_1, target_mask_1, tokens_1, masked_1 = self._get_factorization(
inputs[self._reuse_length:], input_mask[self._reuse_length:])
perm_mask_0 = tf.concat(
[perm_mask_0,
tf.zeros([self._reuse_length, non_reuse_len], dtype=tf.int32)],
axis=1)
perm_mask_1 = tf.concat(
[tf.ones([non_reuse_len, self._reuse_length], dtype=tf.int32),
perm_mask_1], axis=1)
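      # Here 1 means "can attend": reused (memory) tokens never attend to the
      # new segment, while new tokens can always attend to the memory tokens.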
perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
tokens = tf.concat([tokens_0, tokens_1], axis=0)
masked_tokens = tf.concat([masked_0, masked_1], axis=0)
else:
# Disable the memory mechanism.
if self._seq_length % self._permutation_size != 0:
raise ValueError('`seq_length` should be a multiple of '
'`permutation_size`.')
# Permute the entire sequence together
perm_mask, target_mask, tokens, masked_tokens = self._get_factorization(
inputs=inputs, input_mask=input_mask)
x['permutation_mask'] = tf.reshape(
perm_mask, [self._seq_length, self._seq_length])
x['input_word_ids'] = tokens
x['masked_tokens'] = masked_tokens
target = tokens
if self._max_predictions_per_seq is not None:
indices = tf.range(self._seq_length, dtype=tf.int32)
bool_target_mask = tf.cast(target_mask, tf.bool)
indices = tf.boolean_mask(indices, bool_target_mask)
# account for extra padding due to CLS/SEP.
actual_num_predict = tf.shape(indices)[0]
pad_len = self._max_predictions_per_seq - actual_num_predict
target_mapping = tf.one_hot(indices, self._seq_length, dtype=tf.int32)
paddings = tf.zeros([pad_len, self._seq_length],
dtype=target_mapping.dtype)
target_mapping = tf.concat([target_mapping, paddings], axis=0)
x['target_mapping'] = tf.reshape(
target_mapping, [self._max_predictions_per_seq, self._seq_length])
target = tf.boolean_mask(target, bool_target_mask)
paddings = tf.zeros([pad_len], dtype=target.dtype)
target = tf.concat([target, paddings], axis=0)
x['target'] = tf.reshape(target, [self._max_predictions_per_seq])
target_mask = tf.concat([
tf.ones([actual_num_predict], dtype=tf.int32),
tf.zeros([pad_len], dtype=tf.int32)
], axis=0)
x['target_mask'] = tf.reshape(target_mask,
[self._max_predictions_per_seq])
else:
x['target'] = tf.reshape(target, [self._seq_length])
x['target_mask'] = tf.reshape(target_mask, [self._seq_length])
return x
def _index_pair_to_mask(self,
begin_indices: tf.Tensor,
end_indices: tf.Tensor,
inputs: tf.Tensor) -> tf.Tensor:
"""Converts beginning and end indices into an actual mask."""
non_func_mask = tf.logical_and(
tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
all_indices = tf.where(
non_func_mask,
tf.range(self._seq_length, dtype=tf.int32),
tf.constant(-1, shape=[self._seq_length], dtype=tf.int32))
candidate_matrix = tf.cast(
tf.logical_and(all_indices[None, :] >= begin_indices[:, None],
all_indices[None, :] < end_indices[:, None]), tf.float32)
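    # A running count over all candidate positions (flattened row-major) lets
    # the next lines cap the total selected tokens at `max_predictions_per_seq`.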
cumsum_matrix = tf.reshape(
tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, self._seq_length])
masked_matrix = tf.cast(cumsum_matrix <= self._max_predictions_per_seq,
tf.float32)
target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
return tf.cast(target_mask, tf.bool)
def _single_token_mask(self, inputs: tf.Tensor) -> tf.Tensor:
"""Samples individual tokens as prediction targets."""
all_indices = tf.range(self._seq_length, dtype=tf.int32)
non_func_mask = tf.logical_and(
tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
masked_pos = tf.random.shuffle(non_func_indices)
masked_pos = tf.sort(masked_pos[:self._max_predictions_per_seq])
sparse_indices = tf.stack(
[tf.zeros_like(masked_pos), masked_pos], axis=-1)
sparse_indices = tf.cast(sparse_indices, tf.int64)
sparse_indices = tf.sparse.SparseTensor(
sparse_indices,
values=tf.ones_like(masked_pos),
dense_shape=(1, self._seq_length))
target_mask = tf.sparse.to_dense(
sp_input=sparse_indices,
default_value=0)
return tf.squeeze(tf.cast(target_mask, tf.bool))
def _whole_word_mask(self,
inputs: tf.Tensor,
boundary: tf.Tensor) -> tf.Tensor:
"""Samples whole words as prediction targets."""
pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
cand_pair_indices = tf.random.shuffle(
pair_indices)[:self._max_predictions_per_seq]
begin_indices = cand_pair_indices[:, 0]
end_indices = cand_pair_indices[:, 1]
return self._index_pair_to_mask(
begin_indices=begin_indices,
end_indices=end_indices,
inputs=inputs)
def _token_span_mask(self, inputs: tf.Tensor) -> tf.Tensor:
"""Samples token spans as prediction targets."""
min_num_tokens = self._params.min_num_tokens
max_num_tokens = self._params.max_num_tokens
mask_alpha = self._seq_length / self._max_predictions_per_seq
round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)
# Sample span lengths from a zipf distribution
span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
probs = np.array([1.0 / (i + 1) for i in span_len_seq])
probs /= np.sum(probs)
logits = tf.constant(np.log(probs), dtype=tf.float32)
span_lens = tf.random.categorical(
logits=logits[None],
num_samples=self._max_predictions_per_seq,
dtype=tf.int32,
)[0] + min_num_tokens
# Sample the ratio [0.0, 1.0) of left context lengths
span_lens_float = tf.cast(span_lens, tf.float32)
left_ratio = tf.random.uniform(
shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
left_ctx_len = round_to_int(left_ctx_len)
# Compute the offset from left start to the right end
right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
# Get the actual begin and end indices
begin_indices = (
tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
end_indices = begin_indices + span_lens
# Remove out of range indices
valid_idx_mask = end_indices < self._seq_length
begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
# Shuffle valid indices
num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
begin_indices = tf.gather(begin_indices, order)
end_indices = tf.gather(end_indices, order)
return self._index_pair_to_mask(
begin_indices=begin_indices,
end_indices=end_indices,
inputs=inputs)
def _word_span_mask(self,
inputs: tf.Tensor,
boundary: tf.Tensor):
"""Sample whole word spans as prediction targets."""
min_num_words = self._params.min_num_words
max_num_words = self._params.max_num_words
# Note: 1.2 is the token-to-word ratio
mask_alpha = self._seq_length / self._max_predictions_per_seq / 1.2
round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)
# Sample span lengths from a zipf distribution
span_len_seq = np.arange(min_num_words, max_num_words + 1)
probs = np.array([1.0 / (i + 1) for i in span_len_seq])
probs /= np.sum(probs)
logits = tf.constant(np.log(probs), dtype=tf.float32)
    # Sample `num_predict` words here; note that this is oversampling.
span_lens = tf.random.categorical(
logits=logits[None],
num_samples=self._max_predictions_per_seq,
dtype=tf.int32,
)[0] + min_num_words
# Sample the ratio [0.0, 1.0) of left context lengths
span_lens_float = tf.cast(span_lens, tf.float32)
left_ratio = tf.random.uniform(
shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
left_ctx_len = round_to_int(left_ctx_len)
right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
begin_indices = (
tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
end_indices = begin_indices + span_lens
# Remove out of range indices
max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int32)
valid_idx_mask = end_indices < max_boundary_index
begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
begin_indices = tf.gather(boundary, begin_indices)
end_indices = tf.gather(boundary, end_indices)
# Shuffle valid indices
num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
begin_indices = tf.gather(begin_indices, order)
end_indices = tf.gather(end_indices, order)
return self._index_pair_to_mask(
begin_indices=begin_indices,
end_indices=end_indices,
inputs=inputs)
def _online_sample_mask(self,
inputs: tf.Tensor,
boundary: tf.Tensor) -> tf.Tensor:
"""Samples target positions for predictions.
Descriptions of each strategy:
- 'single_token': Samples individual tokens as prediction targets.
- 'token_span': Samples spans of tokens as prediction targets.
- 'whole_word': Samples individual words as prediction targets.
- 'word_span': Samples spans of words as prediction targets.
Args:
inputs: The input tokens.
boundary: The `int` Tensor of indices indicating whole word boundaries.
This is used in 'whole_word' and 'word_span'
Returns:
The sampled `bool` input mask.
Raises:
`ValueError`: if `max_predictions_per_seq` is not set or if boundary is
not provided for 'whole_word' and 'word_span' sample strategies.
"""
if self._max_predictions_per_seq is None:
raise ValueError('`max_predictions_per_seq` must be set.')
if boundary is None and 'word' in self._sample_strategy:
raise ValueError('`boundary` must be provided for {} strategy'.format(
self._sample_strategy))
if self._sample_strategy == 'single_token':
return self._single_token_mask(inputs)
elif self._sample_strategy == 'token_span':
return self._token_span_mask(inputs)
elif self._sample_strategy == 'whole_word':
return self._whole_word_mask(inputs, boundary)
elif self._sample_strategy == 'word_span':
return self._word_span_mask(inputs, boundary)
else:
raise NotImplementedError('Invalid sample strategy.')
def _get_factorization(self,
inputs: tf.Tensor,
input_mask: tf.Tensor):
"""Samples a permutation of the factorization order.
Args:
inputs: the input tokens.
      input_mask: the `bool` Tensor of the same shape as `inputs`. If `True`,
        the position is selected for partial prediction.
Returns:
      perm_mask: An `int32` Tensor of shape [seq_length, seq_length] consisting
        of 0s and 1s. If perm_mask[i][j] == 0, then the i-th token (in original
        order) cannot attend to the j-th token.
target_mask: An `int32` Tensor of shape [seq_len] consisting of 0s and 1s.
If target_mask[i] == 1, then the i-th token needs to be predicted and
the mask will be used as input. This token will be included in the loss.
If target_mask[i] == 0, then the token (or [SEP], [CLS]) will be used as
input. This token will not be included in the loss.
tokens: int32 Tensor of shape [seq_length].
masked_tokens: int32 Tensor of shape [seq_length].
"""
factorization_length = tf.shape(inputs)[0]
# Generate permutation indices
index = tf.range(factorization_length, dtype=tf.int32)
index = tf.transpose(tf.reshape(index, [-1, self._permutation_size]))
index = tf.random.shuffle(index)
index = tf.reshape(tf.transpose(index), [-1])
input_mask = tf.cast(input_mask, tf.bool)
# non-functional tokens
non_func_tokens = tf.logical_not(
tf.logical_or(
tf.equal(inputs, self._sep_id), tf.equal(inputs, self._cls_id)))
masked_tokens = tf.logical_and(input_mask, non_func_tokens)
non_masked_or_func_tokens = tf.logical_not(masked_tokens)
smallest_index = -2 * tf.ones([factorization_length], dtype=tf.int32)
# Similar to BERT, randomly leak some masked tokens
if self._leak_ratio > 0:
leak_tokens = tf.logical_and(
masked_tokens,
tf.random.uniform([factorization_length],
maxval=1.0) < self._leak_ratio)
can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
else:
can_attend_self = non_masked_or_func_tokens
to_index = tf.where(can_attend_self, smallest_index, index)
from_index = tf.where(can_attend_self, to_index + 1, to_index)
# For masked tokens, can attend if i > j
# For context tokens, always can attend each other
can_attend = from_index[:, None] > to_index[None, :]
perm_mask = tf.cast(can_attend, tf.int32)
# Only masked tokens are included in the loss
target_mask = tf.cast(masked_tokens, tf.int32)
return perm_mask, target_mask, inputs, masked_tokens
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
if input_context:
self._num_replicas_in_sync = input_context.num_replicas_in_sync
reader = input_reader.InputReader(
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
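To make the attention rule in `_get_factorization` concrete, here is a tiny
standalone NumPy sketch of the `from_index[:, None] > to_index[None, :]`
construction (illustrative only; the 4-token sequence, sampled order, and mask
are made up, and leakage is ignored):

import numpy as np

index = np.array([2, 0, 3, 1])                 # sampled factorization order
masked = np.array([False, True, False, True])  # positions chosen as targets
smallest_index = -2 * np.ones(4, dtype=np.int64)
to_index = np.where(~masked, smallest_index, index)
from_index = np.where(~masked, to_index + 1, to_index)
# can_attend[i, j] == True means token i may attend to token j.
can_attend = from_index[:, None] > to_index[None, :]
print(can_attend.astype(int))
# [[1 0 1 0]
#  [1 0 1 0]
#  [1 0 1 0]
#  [1 1 1 0]]
# Context tokens (rows 0, 2) attend to all context tokens; a masked token
# attends to the context plus masked tokens earlier in the sampled order
# (row 3 sees column 1, not vice versa), and never to itself.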
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.data.pretrain_dataloader."""
import itertools
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.data import pretrain_dataloader
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
def _create_fake_bert_dataset(
output_path,
seq_length,
max_predictions_per_seq,
use_position_id,
use_next_sentence_label,
use_v2_feature_names=False):
"""Creates a fake dataset."""
writer = tf.io.TFRecordWriter(output_path)
def create_float_feature(values):
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return f
for _ in range(100):
features = {}
input_ids = np.random.randint(100, size=(seq_length))
features["input_mask"] = create_int_feature(np.ones_like(input_ids))
if use_v2_feature_names:
features["input_word_ids"] = create_int_feature(input_ids)
features["input_type_ids"] = create_int_feature(np.ones_like(input_ids))
else:
features["input_ids"] = create_int_feature(input_ids)
features["segment_ids"] = create_int_feature(np.ones_like(input_ids))
features["masked_lm_positions"] = create_int_feature(
np.random.randint(100, size=(max_predictions_per_seq)))
features["masked_lm_ids"] = create_int_feature(
np.random.randint(100, size=(max_predictions_per_seq)))
features["masked_lm_weights"] = create_float_feature(
[1.0] * max_predictions_per_seq)
if use_next_sentence_label:
features["next_sentence_labels"] = create_int_feature([1])
if use_position_id:
features["position_ids"] = create_int_feature(range(0, seq_length))
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
def _create_fake_xlnet_dataset(
output_path, seq_length, max_predictions_per_seq):
"""Creates a fake dataset."""
writer = tf.io.TFRecordWriter(output_path)
for _ in range(100):
features = {}
input_ids = np.random.randint(100, size=(seq_length))
num_boundary_indices = np.random.randint(1, seq_length)
if max_predictions_per_seq is not None:
input_mask = np.zeros_like(input_ids)
input_mask[:max_predictions_per_seq] = 1
np.random.shuffle(input_mask)
else:
input_mask = np.ones_like(input_ids)
features["input_mask"] = create_int_feature(input_mask)
features["input_word_ids"] = create_int_feature(input_ids)
features["input_type_ids"] = create_int_feature(np.ones_like(input_ids))
features["boundary_indices"] = create_int_feature(
sorted(np.random.randint(seq_length, size=(num_boundary_indices))))
features["target"] = create_int_feature(input_ids + 1)
features["label"] = create_int_feature([1])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(itertools.product(
(False, True),
(False, True),
))
def test_load_data(self, use_next_sentence_label, use_position_id):
train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
seq_length = 128
max_predictions_per_seq = 20
_create_fake_bert_dataset(
train_data_path,
seq_length,
max_predictions_per_seq,
use_next_sentence_label=use_next_sentence_label,
use_position_id=use_position_id)
data_config = pretrain_dataloader.BertPretrainDataConfig(
input_path=train_data_path,
max_predictions_per_seq=max_predictions_per_seq,
seq_length=seq_length,
global_batch_size=10,
is_training=True,
use_next_sentence_label=use_next_sentence_label,
use_position_id=use_position_id)
dataset = pretrain_dataloader.BertPretrainDataLoader(data_config).load()
features = next(iter(dataset))
self.assertLen(features,
6 + int(use_next_sentence_label) + int(use_position_id))
self.assertIn("input_word_ids", features)
self.assertIn("input_mask", features)
self.assertIn("input_type_ids", features)
self.assertIn("masked_lm_positions", features)
self.assertIn("masked_lm_ids", features)
self.assertIn("masked_lm_weights", features)
self.assertEqual("next_sentence_labels" in features,
use_next_sentence_label)
self.assertEqual("position_ids" in features, use_position_id)
def test_v2_feature_names(self):
train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
seq_length = 128
max_predictions_per_seq = 20
_create_fake_bert_dataset(
train_data_path,
seq_length,
max_predictions_per_seq,
use_next_sentence_label=True,
use_position_id=False,
use_v2_feature_names=True)
data_config = pretrain_dataloader.BertPretrainDataConfig(
input_path=train_data_path,
max_predictions_per_seq=max_predictions_per_seq,
seq_length=seq_length,
global_batch_size=10,
is_training=True,
use_next_sentence_label=True,
use_position_id=False,
use_v2_feature_names=True)
dataset = pretrain_dataloader.BertPretrainDataLoader(data_config).load()
features = next(iter(dataset))
self.assertIn("input_word_ids", features)
self.assertIn("input_mask", features)
self.assertIn("input_type_ids", features)
self.assertIn("masked_lm_positions", features)
self.assertIn("masked_lm_ids", features)
self.assertIn("masked_lm_weights", features)
class XLNetPretrainDataTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(itertools.product(
("single_token", "whole_word", "token_span"),
(0, 64),
(20, None),
))
def test_load_data(
self, sample_strategy, reuse_length, max_predictions_per_seq):
train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
seq_length = 128
batch_size = 5
_create_fake_xlnet_dataset(
train_data_path, seq_length, max_predictions_per_seq)
data_config = pretrain_dataloader.XLNetPretrainDataConfig(
input_path=train_data_path,
max_predictions_per_seq=max_predictions_per_seq,
seq_length=seq_length,
global_batch_size=batch_size,
is_training=True,
reuse_length=reuse_length,
sample_strategy=sample_strategy,
min_num_tokens=1,
max_num_tokens=2,
permutation_size=seq_length // 2,
leak_ratio=0.1)
if max_predictions_per_seq is None:
with self.assertRaises(ValueError):
dataset = pretrain_dataloader.XLNetPretrainDataLoader(
data_config).load()
features = next(iter(dataset))
else:
dataset = pretrain_dataloader.XLNetPretrainDataLoader(data_config).load()
features = next(iter(dataset))
self.assertIn("input_word_ids", features)
self.assertIn("input_type_ids", features)
self.assertIn("permutation_mask", features)
self.assertIn("masked_tokens", features)
self.assertIn("target", features)
self.assertIn("target_mask", features)
self.assertAllClose(features["input_word_ids"].shape,
(batch_size, seq_length))
self.assertAllClose(features["input_type_ids"].shape,
(batch_size, seq_length))
self.assertAllClose(features["permutation_mask"].shape,
(batch_size, seq_length, seq_length))
self.assertAllClose(features["masked_tokens"].shape,
(batch_size, seq_length,))
if max_predictions_per_seq is not None:
self.assertIn("target_mapping", features)
self.assertAllClose(features["target_mapping"].shape,
(batch_size, max_predictions_per_seq, seq_length))
self.assertAllClose(features["target_mask"].shape,
(batch_size, max_predictions_per_seq))
self.assertAllClose(features["target"].shape,
(batch_size, max_predictions_per_seq))
else:
self.assertAllClose(features["target_mask"].shape,
(batch_size, seq_length))
self.assertAllClose(features["target"].shape,
(batch_size, seq_length))
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset loader for the pre-training with dynamic sequence length."""
from typing import Optional, Tuple
import dataclasses
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader_factory
from official.nlp.data import pretrain_dataloader
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
"""Data config for BERT pretraining task (tasks/masked_lm)."""
input_path: str = ''
global_batch_size: int = 512
is_training: bool = True
seq_bucket_lengths: Tuple[int, ...] = (128, 256, 384, 512,)
# TODO(rxsang): `seq_bucket_window_scale` is only useful when round robin
# tf.data service is disabled. Deprecate this flag once we always enable round
# robin tf.data service.
seq_bucket_window_scale: int = 8
use_next_sentence_label: bool = True
use_position_id: bool = False
deterministic: bool = False
enable_tf_data_service: bool = False
enable_round_robin_tf_data_service: bool = False
tf_data_service_job_name: str = 'bert_pretrain'
use_v2_feature_names: bool = False
@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
"""Dataset loader for bert-style pretraining with dynamic sequenece length.
Bucketizes the input id features by the seq_bucket_lengths and features are
padded to the bucket boundaries. The mask features are usually short than
input id features and can also be dynamic. We require the mask feature lengths
within a bucket must be the same. For example, with [128, 256] buckets,
the mask features for bucket 128 should always have the length as X and
features for bucket 256 should always have the length as Y.
The dataloader does not filter out empty masks. Make sure to handle this
in the model.
"""
def __init__(self, params):
self._params = params
if len(params.seq_bucket_lengths) < 1:
raise ValueError('The seq_bucket_lengths cannot be empty.')
self._seq_bucket_lengths = params.seq_bucket_lengths
self._seq_bucket_window_scale = params.seq_bucket_window_scale
self._global_batch_size = params.global_batch_size
self._use_next_sentence_label = params.use_next_sentence_label
self._use_position_id = params.use_position_id
self._drop_remainder = params.drop_remainder
self._enable_tf_data_service = params.enable_tf_data_service
self._enable_round_robin_tf_data_service = (
params.enable_round_robin_tf_data_service)
self._mask_keys = [
'masked_lm_positions', 'masked_lm_ids', 'masked_lm_weights'
]
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'input_ids': tf.io.VarLenFeature(tf.int64),
'input_mask': tf.io.VarLenFeature(tf.int64),
'segment_ids': tf.io.VarLenFeature(tf.int64),
'masked_lm_positions': tf.io.VarLenFeature(tf.int64),
'masked_lm_ids': tf.io.VarLenFeature(tf.int64),
'masked_lm_weights': tf.io.VarLenFeature(tf.float32),
}
if self._use_next_sentence_label:
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
tf.int64)
dynamic_keys = ['input_ids', 'input_mask', 'segment_ids']
if self._use_position_id:
name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64)
dynamic_keys.append('position_ids')
example = tf.io.parse_single_example(record, name_to_features)
for key in dynamic_keys + self._mask_keys:
example[key] = tf.sparse.to_dense(example[key])
# Truncate padded data after the first non pad in the
# sequence length dimension.
# Pad before the first non pad from the back should not be removed.
mask = tf.math.greater(
tf.math.cumsum(example['input_ids'], reverse=True), 0)
for key in dynamic_keys:
example[key] = tf.boolean_mask(example[key], mask)
# masked_lm_ids should be 0 padded.
# Change mask features to -1 padding so that we can differentiate
# padding from data or from bucketizing.
mask = tf.math.not_equal(example['masked_lm_ids'], 0)
    example['masked_lm_ids'] = tf.where(
        mask, example['masked_lm_ids'],
        -tf.ones(tf.shape(example['masked_lm_ids']),
                 dtype=example['masked_lm_ids'].dtype))
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
# tf.data service uses dataset graph fingerprint to distinguish input
# pipeline jobs, thus we sort the keys here to make sure they are generated
# in a deterministic order each time the dataset function is traced.
for name in sorted(list(example.keys())):
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _bucketize_and_batch(
self,
dataset,
input_context: Optional[tf.distribute.InputContext] = None):
"""Bucketize by sequence length and batch the datasets."""
per_replica_batch_size = input_context.get_per_replica_batch_size(
self._global_batch_size) if input_context else self._global_batch_size
def element_length_func(example, seq_len_dim):
return tf.shape(example['input_word_ids'])[seq_len_dim]
bucket_boundaries = [length + 1 for length in self._seq_bucket_lengths]
bucket_batch_sizes = [per_replica_batch_size] * (len(bucket_boundaries) + 1)
# Bucketize and batch the dataset with per replica batch size first.
dataset = dataset.apply(
tf.data.experimental.bucket_by_sequence_length(
lambda example: tf.cast(element_length_func(example, 0), tf.int32),
bucket_boundaries,
bucket_batch_sizes,
pad_to_bucket_boundary=True,
drop_remainder=self._drop_remainder))
if input_context:
window_size = input_context.num_replicas_in_sync
if self._enable_tf_data_service and (
not self._enable_round_robin_tf_data_service):
# If tf.data service is enabled but round-robin behavior is not enabled,
# different TPU workers may fetch data from one tf.data service worker
# in different speed. We set the window size to be
# `seq_bucket_window_scale` larger to leave buffer if some workers are
# fetching data faster than others, so all the data within the same
# global batch can still have more chances to be in the same bucket.
window_size *= self._seq_bucket_window_scale
# Group `num_replicas_in_sync` batches from same bucket together, so all
# replicas can get the same sequence length for one global step.
dataset = dataset.apply(
tf.data.experimental.group_by_window(
key_func=lambda example: tf.cast( # pylint: disable=g-long-lambda
element_length_func(example, 1), tf.int64),
reduce_func=lambda _, x: tf.data.Dataset.from_tensors(x),
window_size=window_size))
dataset = dataset.flat_map(lambda x: x)
def _remove_pads_from_bucketize(features):
# All mask features must have the same effective length.
# The real masked ids padding token is -1 and 0 comes from
# bucket_by_sequence_length.
mask = tf.math.not_equal(features['masked_lm_ids'], 0)
mask_per_example = tf.math.reduce_sum(tf.cast(mask, tf.int32), axis=1)
normalized = tf.cast(
mask_per_example / tf.math.reduce_max(mask_per_example), tf.int32)
assert_op = tf.debugging.assert_equal(
tf.math.reduce_sum(normalized), per_replica_batch_size,
'Number of non padded mask tokens is not the same for each example '
'in the same sequence length.')
with tf.control_dependencies([assert_op]):
for key in self._mask_keys:
features[key] = tf.reshape(
tf.boolean_mask(
features[key], mask), [per_replica_batch_size, -1])
# Revert masked_lm_ids to be 0-padded.
mask = tf.math.not_equal(features['masked_lm_ids'], -1)
features['masked_lm_ids'] = tf.where(
mask, features['masked_lm_ids'],
tf.zeros(
tf.shape(features['masked_lm_ids']),
dtype=features['masked_lm_ids'].dtype))
return features
dataset = dataset.map(_remove_pads_from_bucketize)
return dataset
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params,
decoder_fn=self._decode,
parser_fn=self._parse,
transform_and_batch_fn=self._bucketize_and_batch)
return reader.read(input_context)
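For intuition about the bucketing step, a self-contained toy run of
tf.data.experimental.bucket_by_sequence_length behaves as follows (illustrative
only; the boundaries and sequence lengths are made up):

import tensorflow as tf

ds = tf.data.Dataset.from_generator(
    lambda: ([1] * n for n in [3, 5, 9, 4, 10]),
    output_signature=tf.TensorSpec([None], tf.int32))
ds = ds.apply(
    tf.data.experimental.bucket_by_sequence_length(
        element_length_func=lambda x: tf.shape(x)[0],
        # Boundaries are exclusive, hence the `length + 1` trick in the loader:
        # a boundary of 6 pads the short bucket to length 5.
        bucket_boundaries=[6, 12],
        bucket_batch_sizes=[2, 2, 2],
        pad_to_bucket_boundary=True,
        drop_remainder=True))
for batch in ds:
  print(batch.shape)  # (2, 5) for lengths 3 and 5, then (2, 11) for 9 and 10.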
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.data.pretrain_dynamic_dataloader."""
import os
from absl import logging
from absl.testing import parameterized
import numpy as np
import orbit
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.data import pretrain_dynamic_dataloader
from official.nlp.tasks import masked_lm
def _create_fake_dataset(output_path, seq_length, num_masked_tokens,
max_seq_length, num_examples):
"""Creates a fake dataset."""
writer = tf.io.TFRecordWriter(output_path)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
def create_float_feature(values):
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return f
for _ in range(num_examples):
features = {}
padding = np.zeros(shape=(max_seq_length - seq_length), dtype=np.int32)
input_ids = np.random.randint(low=1, high=100, size=(seq_length))
features['input_ids'] = create_int_feature(
np.concatenate((input_ids, padding)))
features['input_mask'] = create_int_feature(
np.concatenate((np.ones_like(input_ids), padding)))
features['segment_ids'] = create_int_feature(
np.concatenate((np.ones_like(input_ids), padding)))
features['position_ids'] = create_int_feature(
np.concatenate((np.ones_like(input_ids), padding)))
features['masked_lm_positions'] = create_int_feature(
np.random.randint(60, size=(num_masked_tokens), dtype=np.int64))
features['masked_lm_ids'] = create_int_feature(
np.random.randint(100, size=(num_masked_tokens), dtype=np.int64))
features['masked_lm_weights'] = create_float_feature(
np.ones((num_masked_tokens,), dtype=np.float32))
features['next_sentence_labels'] = create_int_feature(np.array([0]))
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
class PretrainDynamicDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.cloud_tpu_strategy,
],
mode='eager'))
def test_distribution_strategy(self, distribution_strategy):
max_seq_length = 128
batch_size = 8
input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
_create_fake_dataset(
input_path,
seq_length=60,
num_masked_tokens=20,
max_seq_length=max_seq_length,
num_examples=batch_size)
data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
is_training=False,
input_path=input_path,
seq_bucket_lengths=[64, 128],
global_batch_size=batch_size)
dataloader = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
data_config)
distributed_ds = orbit.utils.make_distributed_dataset(
distribution_strategy, dataloader.load)
train_iter = iter(distributed_ds)
with distribution_strategy.scope():
config = masked_lm.MaskedLMConfig(
init_checkpoint=self.get_temp_dir(),
model=bert.PretrainerConfig(
encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(
vocab_size=30522, num_layers=1)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name='next_sentence')
]),
train_data=data_config)
task = masked_lm.MaskedLMTask(config)
model = task.build_model()
metrics = task.build_metrics()
@tf.function
def step_fn(features):
return task.validation_step(features, model, metrics=metrics)
distributed_outputs = distribution_strategy.run(
step_fn, args=(next(train_iter),))
local_results = tf.nest.map_structure(
distribution_strategy.experimental_local_results, distributed_outputs)
logging.info('Dynamic padding: local_results= %s', str(local_results))
dynamic_metrics = {}
for metric in metrics:
dynamic_metrics[metric.name] = metric.result()
data_config = pretrain_dataloader.BertPretrainDataConfig(
is_training=False,
input_path=input_path,
seq_length=max_seq_length,
max_predictions_per_seq=20,
global_batch_size=batch_size)
dataloader = pretrain_dataloader.BertPretrainDataLoader(data_config)
distributed_ds = orbit.utils.make_distributed_dataset(
distribution_strategy, dataloader.load)
train_iter = iter(distributed_ds)
with distribution_strategy.scope():
metrics = task.build_metrics()
@tf.function
def step_fn_b(features):
return task.validation_step(features, model, metrics=metrics)
distributed_outputs = distribution_strategy.run(
step_fn_b, args=(next(train_iter),))
local_results = tf.nest.map_structure(
distribution_strategy.experimental_local_results, distributed_outputs)
logging.info('Static padding: local_results= %s', str(local_results))
static_metrics = {}
for metric in metrics:
static_metrics[metric.name] = metric.result()
for key in static_metrics:
      # We still need to investigate the differences in the losses.
if key != 'next_sentence_loss':
self.assertEqual(dynamic_metrics[key], static_metrics[key])
def test_load_dataset(self):
max_seq_length = 128
batch_size = 2
input_path_1 = os.path.join(self.get_temp_dir(), 'train_1.tf_record')
_create_fake_dataset(
input_path_1,
seq_length=60,
num_masked_tokens=20,
max_seq_length=max_seq_length,
num_examples=batch_size)
input_path_2 = os.path.join(self.get_temp_dir(), 'train_2.tf_record')
_create_fake_dataset(
input_path_2,
seq_length=100,
num_masked_tokens=70,
max_seq_length=max_seq_length,
num_examples=batch_size)
input_paths = ','.join([input_path_1, input_path_2])
data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
is_training=False,
input_path=input_paths,
seq_bucket_lengths=[64, 128],
use_position_id=True,
global_batch_size=batch_size)
dataset = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
data_config).load()
dataset_it = iter(dataset)
features = next(dataset_it)
self.assertCountEqual([
'input_word_ids',
'input_mask',
'input_type_ids',
'next_sentence_labels',
'masked_lm_positions',
'masked_lm_ids',
'masked_lm_weights',
'position_ids',
], features.keys())
    # The sequence length dimension should be bucketized and padded to 64.
self.assertEqual(features['input_word_ids'].shape, (batch_size, 64))
self.assertEqual(features['input_mask'].shape, (batch_size, 64))
self.assertEqual(features['input_type_ids'].shape, (batch_size, 64))
self.assertEqual(features['position_ids'].shape, (batch_size, 64))
self.assertEqual(features['masked_lm_positions'].shape, (batch_size, 20))
features = next(dataset_it)
self.assertEqual(features['input_word_ids'].shape, (batch_size, 128))
self.assertEqual(features['input_mask'].shape, (batch_size, 128))
self.assertEqual(features['input_type_ids'].shape, (batch_size, 128))
self.assertEqual(features['position_ids'].shape, (batch_size, 128))
self.assertEqual(features['masked_lm_positions'].shape, (batch_size, 70))
def test_load_dataset_not_same_masks(self):
max_seq_length = 128
batch_size = 2
input_path_1 = os.path.join(self.get_temp_dir(), 'train_3.tf_record')
_create_fake_dataset(
input_path_1,
seq_length=60,
num_masked_tokens=20,
max_seq_length=max_seq_length,
num_examples=batch_size)
input_path_2 = os.path.join(self.get_temp_dir(), 'train_4.tf_record')
_create_fake_dataset(
input_path_2,
seq_length=60,
num_masked_tokens=15,
max_seq_length=max_seq_length,
num_examples=batch_size)
input_paths = ','.join([input_path_1, input_path_2])
data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
is_training=False,
input_path=input_paths,
seq_bucket_lengths=[64, 128],
use_position_id=True,
global_batch_size=batch_size * 2)
dataset = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
data_config).load()
dataset_it = iter(dataset)
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError, '.*Number of non padded mask tokens.*'):
next(dataset_it)
if __name__ == '__main__':
tf.test.main()
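# A standalone sketch of the bucketing idea behind `seq_bucket_lengths` above,
# using the stock tf.data transform as a stand-in; this is an assumption about
# roughly equivalent behavior, not the loader's actual implementation.
import tensorflow as tf

def bucketing_demo():
  lengths = [60, 60, 100, 100]
  ds = tf.data.Dataset.from_generator(
      lambda: (tf.ones([n], tf.int32) for n in lengths),
      output_signature=tf.TensorSpec([None], tf.int32))
  ds = ds.apply(
      tf.data.experimental.bucket_by_sequence_length(
          element_length_func=lambda x: tf.shape(x)[0],
          bucket_boundaries=[65, 129],   # buckets: <=64 and 65..128
          bucket_batch_sizes=[2, 2, 2],  # one entry per bucket plus overflow
          pad_to_bucket_boundary=True))
  for batch in ds:
    # Short examples come out as (2, 64), long ones as (2, 128), mirroring
    # the shape assertions in the test above.
    print(batch.shape)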
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,20 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ==============================================================================
 """Loads dataset for the question answering (e.g, SQuAD) task."""
 from typing import Mapping, Optional
 import dataclasses
 import tensorflow as tf
+from official.core import config_definitions as cfg
 from official.core import input_reader
-from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader
 from official.nlp.data import data_loader_factory
 @dataclasses.dataclass
 class QADataConfig(cfg.DataConfig):
   """Data config for question answering task (tasks/question_answering)."""
+  # For training, `input_path` is expected to be a pre-processed TFRecord file,
+  # while for evaluation, it is expected to be a raw JSON file (b/173814590).
   input_path: str = ''
   global_batch_size: int = 48
   is_training: bool = True
@@ -36,19 +38,23 @@ class QADataConfig(cfg.DataConfig):
   input_preprocessed_data_path: str = ''
   doc_stride: int = 128
   query_length: int = 64
+  # The path to the vocab file of word piece tokenizer or the
+  # model of the sentence piece tokenizer.
   vocab_file: str = ''
   tokenization: str = 'WordPiece'  # WordPiece or SentencePiece
   do_lower_case: bool = True
+  xlnet_format: bool = False
 @data_loader_factory.register_data_loader_cls(QADataConfig)
-class QuestionAnsweringDataLoader:
+class QuestionAnsweringDataLoader(data_loader.DataLoader):
   """A class to load dataset for sentence prediction (classification) task."""
   def __init__(self, params):
     self._params = params
     self._seq_length = params.seq_length
     self._is_training = params.is_training
+    self._xlnet_format = params.xlnet_format
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
@@ -57,6 +63,13 @@ class QuestionAnsweringDataLoader:
         'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
     }
+    if self._xlnet_format:
+      name_to_features['class_index'] = tf.io.FixedLenFeature([], tf.int64)
+      name_to_features['paragraph_mask'] = tf.io.FixedLenFeature(
+          [self._seq_length], tf.int64)
+      if self._is_training:
+        name_to_features['is_impossible'] = tf.io.FixedLenFeature([], tf.int64)
     if self._is_training:
       name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
       name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
@@ -78,7 +91,7 @@ class QuestionAnsweringDataLoader:
     """Parses raw tensors into a dict of tensors to be consumed by the model."""
     x, y = {}, {}
     for name, tensor in record.items():
-      if name in ('start_positions', 'end_positions'):
+      if name in ('start_positions', 'end_positions', 'is_impossible'):
         y[name] = tensor
       elif name == 'input_ids':
         x['input_word_ids'] = tensor
@@ -86,6 +99,8 @@ class QuestionAnsweringDataLoader:
         x['input_type_ids'] = tensor
       else:
        x[name] = tensor
+      if name == 'start_positions' and self._xlnet_format:
+        x[name] = tensor
     return (x, y)
   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.data.question_answering_dataloader."""
import os
import numpy as np
import tensorflow as tf
from official.nlp.data import question_answering_dataloader
def _create_fake_dataset(output_path, seq_length):
"""Creates a fake dataset."""
writer = tf.io.TFRecordWriter(output_path)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
for _ in range(100):
features = {}
input_ids = np.random.randint(100, size=(seq_length))
features['input_ids'] = create_int_feature(input_ids)
features['input_mask'] = create_int_feature(np.ones_like(input_ids))
features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
features['start_positions'] = create_int_feature(np.array([0]))
features['end_positions'] = create_int_feature(np.array([10]))
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
class QuestionAnsweringDataTest(tf.test.TestCase):
def test_load_dataset(self):
seq_length = 128
batch_size = 10
input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
_create_fake_dataset(input_path, seq_length)
data_config = question_answering_dataloader.QADataConfig(
is_training=True,
input_path=input_path,
seq_length=seq_length,
global_batch_size=batch_size)
dataset = question_answering_dataloader.QuestionAnsweringDataLoader(
data_config).load()
features, labels = next(iter(dataset))
self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertCountEqual(['start_positions', 'end_positions'], labels.keys())
self.assertEqual(labels['start_positions'].shape, (batch_size,))
self.assertEqual(labels['end_positions'].shape, (batch_size,))
if __name__ == '__main__':
tf.test.main()
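# A self-contained sketch (not part of the original test) of the decode step
# the loader above performs: build one serialized example and parse it back
# with the same kind of fixed-length feature spec.
import tensorflow as tf

def decode_demo(seq_length=8):
  feature = {
      'input_ids': tf.train.Feature(
          int64_list=tf.train.Int64List(value=list(range(seq_length)))),
      'start_positions': tf.train.Feature(
          int64_list=tf.train.Int64List(value=[0])),
  }
  serialized = tf.train.Example(
      features=tf.train.Features(feature=feature)).SerializeToString()
  spec = {
      'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
      'start_positions': tf.io.FixedLenFeature([], tf.int64),
  }
  parsed = tf.io.parse_single_example(serialized, spec)
  # Cast to int32, as the loaders above do for TPU compatibility.
  return {name: tf.cast(t, tf.int32) for name, t in parsed.items()}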
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,16 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ==============================================================================
 """Loads dataset for the sentence prediction (classification) task."""
-from typing import Mapping, Optional
+import functools
+from typing import List, Mapping, Optional
 import dataclasses
 import tensorflow as tf
+import tensorflow_hub as hub
+from official.common import dataset_fn
+from official.core import config_definitions as cfg
 from official.core import input_reader
-from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp import modeling
+from official.nlp.data import data_loader
 from official.nlp.data import data_loader_factory
+LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
 @dataclasses.dataclass
 class SentencePredictionDataConfig(cfg.DataConfig):
@@ -30,24 +37,32 @@ class SentencePredictionDataConfig(cfg.DataConfig):
   global_batch_size: int = 32
   is_training: bool = True
   seq_length: int = 128
+  label_type: str = 'int'
+  # Whether to include the example id number.
+  include_example_id: bool = False
 @data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
-class SentencePredictionDataLoader:
+class SentencePredictionDataLoader(data_loader.DataLoader):
   """A class to load dataset for sentence prediction (classification) task."""
   def __init__(self, params):
     self._params = params
     self._seq_length = params.seq_length
+    self._include_example_id = params.include_example_id
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
+    label_type = LABEL_TYPES_MAP[self._params.label_type]
     name_to_features = {
         'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'label_ids': tf.io.FixedLenFeature([], tf.int64),
+        'label_ids': tf.io.FixedLenFeature([], label_type),
     }
+    if self._include_example_id:
+      name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
     example = tf.io.parse_single_example(record, name_to_features)
     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
@@ -67,6 +82,9 @@ class SentencePredictionDataLoader:
         'input_mask': record['input_mask'],
         'input_type_ids': record['segment_ids']
     }
+    if self._include_example_id:
+      x['example_id'] = record['example_id']
     y = record['label_ids']
     return (x, y)
@@ -75,3 +93,147 @@ class SentencePredictionDataLoader:
     reader = input_reader.InputReader(
         params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
     return reader.read(input_context)
@dataclasses.dataclass
class SentencePredictionTextDataConfig(cfg.DataConfig):
"""Data config for sentence prediction task with raw text."""
# Either set `input_path`...
input_path: str = ''
# Either `int` or `float`.
label_type: str = 'int'
# ...or `tfds_name` and `tfds_split` to specify input.
tfds_name: str = ''
tfds_split: str = ''
# The name of the text feature fields. The text features will be
# concatenated in order.
text_fields: Optional[List[str]] = None
label_field: str = 'label'
global_batch_size: int = 32
seq_length: int = 128
is_training: bool = True
# Either build preprocessing with Python code by specifying these values
# for modeling.layers.BertTokenizer()/SentencepieceTokenizer()....
tokenization: str = 'WordPiece' # WordPiece or SentencePiece
# Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
# file if tokenization is SentencePiece.
vocab_file: str = ''
lower_case: bool = True
# ...or load preprocessing from a SavedModel at this location.
preprocessing_hub_module_url: str = ''
  # Either 'tfrecord', 'sstable', or 'recordio'.
file_type: str = 'tfrecord'
include_example_id: bool = False
class TextProcessor(tf.Module):
"""Text features processing for sentence prediction task."""
def __init__(self,
seq_length: int,
vocab_file: Optional[str] = None,
tokenization: Optional[str] = None,
lower_case: Optional[bool] = True,
preprocessing_hub_module_url: Optional[str] = None):
if preprocessing_hub_module_url:
self._preprocessing_hub_module = hub.load(preprocessing_hub_module_url)
self._tokenizer = self._preprocessing_hub_module.tokenize
self._pack_inputs = functools.partial(
self._preprocessing_hub_module.bert_pack_inputs,
seq_length=seq_length)
return
if tokenization == 'WordPiece':
self._tokenizer = modeling.layers.BertTokenizer(
vocab_file=vocab_file, lower_case=lower_case)
elif tokenization == 'SentencePiece':
self._tokenizer = modeling.layers.SentencepieceTokenizer(
model_file_path=vocab_file, lower_case=lower_case,
strip_diacritics=True) # Strip diacritics to follow ALBERT model
else:
raise ValueError('Unsupported tokenization: %s' % tokenization)
self._pack_inputs = modeling.layers.BertPackInputs(
seq_length=seq_length,
special_tokens_dict=self._tokenizer.get_special_tokens_dict())
def __call__(self, segments):
segments = [self._tokenizer(s) for s in segments]
# BertTokenizer returns a RaggedTensor with shape [batch, word, subword],
# and SentencepieceTokenizer returns a RaggedTensor with shape
    # [batch, sentencepiece]. Merge the word and subword axes when present.
segments = [
tf.cast(x.merge_dims(1, -1) if x.shape.rank > 2 else x, tf.int32)
for x in segments
]
return self._pack_inputs(segments)
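# A hedged usage sketch for TextProcessor (not part of the original file);
# 'vocab.txt' is a hypothetical WordPiece vocab path:
#
#   processor = TextProcessor(
#       seq_length=128, vocab_file='vocab.txt',
#       tokenization='WordPiece', lower_case=True)
#   inputs = processor([tf.constant(['hello world']),
#                       tf.constant(['how are you'])])
#   # `inputs` is a dict with 'input_word_ids', 'input_mask' and
#   # 'input_type_ids', each of shape [1, 128], packed as two segments.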
@data_loader_factory.register_data_loader_cls(SentencePredictionTextDataConfig)
class SentencePredictionTextDataLoader(data_loader.DataLoader):
"""Loads dataset with raw text for sentence prediction task."""
def __init__(self, params):
if bool(params.tfds_name) != bool(params.tfds_split):
raise ValueError('`tfds_name` and `tfds_split` should be specified or '
'unspecified at the same time.')
if bool(params.tfds_name) == bool(params.input_path):
raise ValueError('Must specify either `tfds_name` and `tfds_split` '
'or `input_path`.')
if not params.text_fields:
raise ValueError('Unexpected empty text fields.')
if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
raise ValueError('Must specify exactly one of vocab_file (with matching '
'lower_case flag) or preprocessing_hub_module_url.')
self._params = params
self._text_fields = params.text_fields
self._label_field = params.label_field
self._label_type = params.label_type
self._include_example_id = params.include_example_id
self._text_processor = TextProcessor(
seq_length=params.seq_length,
vocab_file=params.vocab_file,
tokenization=params.tokenization,
lower_case=params.lower_case,
preprocessing_hub_module_url=params.preprocessing_hub_module_url)
def _bert_preprocess(self, record: Mapping[str, tf.Tensor]):
"""Berts preprocess."""
segments = [record[x] for x in self._text_fields]
model_inputs = self._text_processor(segments)
if self._include_example_id:
model_inputs['example_id'] = record['example_id']
y = record[self._label_field]
return model_inputs, y
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {}
for text_field in self._text_fields:
name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
label_type = LABEL_TYPES_MAP[self._label_type]
name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type)
if self._include_example_id:
name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
decoder_fn=self._decode if self._params.input_path else None,
params=self._params,
postprocess_fn=self._bert_preprocess)
return reader.read(input_context)
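# The decorator used above comes from official.nlp.data.data_loader_factory.
# As a rough self-contained sketch of that registry pattern (an assumption
# about its shape, not the actual implementation):
#
#   _REGISTRY = {}
#
#   def register_data_loader_cls(config_cls):
#     def decorator(loader_cls):
#       _REGISTRY[config_cls] = loader_cls
#       return loader_cls
#     return decorator
#
#   def get_data_loader(config):
#     return _REGISTRY[type(config)](config)
#
# so a SentencePredictionTextDataConfig instance resolves to
# SentencePredictionTextDataLoader without callers naming the class.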
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.data.sentence_prediction_dataloader."""
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from sentencepiece import SentencePieceTrainer
from official.nlp.data import sentence_prediction_dataloader as loader
def _create_fake_preprocessed_dataset(output_path, seq_length, label_type):
"""Creates a fake dataset."""
writer = tf.io.TFRecordWriter(output_path)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
def create_float_feature(values):
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return f
for _ in range(100):
features = {}
input_ids = np.random.randint(100, size=(seq_length))
features['input_ids'] = create_int_feature(input_ids)
features['input_mask'] = create_int_feature(np.ones_like(input_ids))
features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
if label_type == 'int':
features['label_ids'] = create_int_feature([1])
elif label_type == 'float':
features['label_ids'] = create_float_feature([0.5])
else:
raise ValueError('Unsupported label_type: %s' % label_type)
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
def _create_fake_raw_dataset(output_path, text_fields, label_type):
"""Creates a fake tf record file."""
writer = tf.io.TFRecordWriter(output_path)
def create_str_feature(value):
f = tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
return f
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
def create_float_feature(values):
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return f
for _ in range(100):
features = {}
for text_field in text_fields:
features[text_field] = create_str_feature([b'hello world'])
if label_type == 'int':
features['label'] = create_int_feature([0])
elif label_type == 'float':
features['label'] = create_float_feature([0.5])
else:
raise ValueError('Unexpected label_type: %s' % label_type)
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
def _create_fake_sentencepiece_model(output_dir):
vocab = ['a', 'b', 'c', 'd', 'e', 'abc', 'def', 'ABC', 'DEF']
model_prefix = os.path.join(output_dir, 'spm_model')
input_text_file_path = os.path.join(output_dir, 'train_input.txt')
with tf.io.gfile.GFile(input_text_file_path, 'w') as f:
f.write(' '.join(vocab + ['\n']))
# Add 7 more tokens: <pad>, <unk>, [CLS], [SEP], [MASK], <s>, </s>.
full_vocab_size = len(vocab) + 7
flags = dict(
model_prefix=model_prefix,
model_type='word',
input=input_text_file_path,
pad_id=0,
unk_id=1,
control_symbols='[CLS],[SEP],[MASK]',
vocab_size=full_vocab_size,
bos_id=full_vocab_size - 2,
eos_id=full_vocab_size - 1)
SentencePieceTrainer.Train(' '.join(
['--{}={}'.format(k, v) for k, v in flags.items()]))
return model_prefix + '.model'
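# A quick sketch (not part of the original test) of how the model trained
# above could be sanity-checked with the sentencepiece runtime API:
#
#   import sentencepiece as spm
#   sp = spm.SentencePieceProcessor()
#   sp.Load(model_prefix + '.model')
#   print(sp.EncodeAsPieces('abc def'))  # word-level pieces from the toy vocab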
def _create_fake_vocab_file(vocab_file_path):
tokens = ['[PAD]']
for i in range(1, 100):
tokens.append('[unused%d]' % i)
tokens.extend(['[UNK]', '[CLS]', '[SEP]', '[MASK]', 'hello', 'world'])
with tf.io.gfile.GFile(vocab_file_path, 'w') as outfile:
outfile.write('\n'.join(tokens))
class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('int', tf.int32), ('float', tf.float32))
def test_load_dataset(self, label_type, expected_label_type):
input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
batch_size = 10
seq_length = 128
_create_fake_preprocessed_dataset(input_path, seq_length, label_type)
data_config = loader.SentencePredictionDataConfig(
input_path=input_path,
seq_length=seq_length,
global_batch_size=batch_size,
label_type=label_type)
dataset = loader.SentencePredictionDataLoader(data_config).load()
features, labels = next(iter(dataset))
self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(labels.shape, (batch_size,))
self.assertEqual(labels.dtype, expected_label_type)
class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
parameterized.TestCase):
@parameterized.parameters(True, False)
def test_python_wordpiece_preprocessing(self, use_tfds):
batch_size = 10
seq_length = 256 # Non-default value.
lower_case = True
tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
text_fields = ['sentence1', 'sentence2']
if not use_tfds:
_create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')
vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
_create_fake_vocab_file(vocab_file_path)
data_config = loader.SentencePredictionTextDataConfig(
input_path='' if use_tfds else tf_record_path,
tfds_name='glue/mrpc' if use_tfds else '',
tfds_split='train' if use_tfds else '',
text_fields=text_fields,
global_batch_size=batch_size,
seq_length=seq_length,
is_training=True,
lower_case=lower_case,
vocab_file=vocab_file_path)
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features, labels = next(iter(dataset))
self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(labels.shape, (batch_size,))
@parameterized.parameters(True, False)
def test_python_sentencepiece_preprocessing(self, use_tfds):
batch_size = 10
seq_length = 256 # Non-default value.
lower_case = True
tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
text_fields = ['sentence1', 'sentence2']
if not use_tfds:
_create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')
sp_model_file_path = _create_fake_sentencepiece_model(self.get_temp_dir())
data_config = loader.SentencePredictionTextDataConfig(
input_path='' if use_tfds else tf_record_path,
tfds_name='glue/mrpc' if use_tfds else '',
tfds_split='train' if use_tfds else '',
text_fields=text_fields,
global_batch_size=batch_size,
seq_length=seq_length,
is_training=True,
lower_case=lower_case,
tokenization='SentencePiece',
vocab_file=sp_model_file_path,
)
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features, labels = next(iter(dataset))
self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(labels.shape, (batch_size,))
@parameterized.parameters(True, False)
def test_saved_model_preprocessing(self, use_tfds):
batch_size = 10
seq_length = 256 # Non-default value.
tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
text_fields = ['sentence1', 'sentence2']
if not use_tfds:
_create_fake_raw_dataset(tf_record_path, text_fields, label_type='float')
vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
_create_fake_vocab_file(vocab_file_path)
data_config = loader.SentencePredictionTextDataConfig(
input_path='' if use_tfds else tf_record_path,
tfds_name='glue/mrpc' if use_tfds else '',
tfds_split='train' if use_tfds else '',
text_fields=text_fields,
global_batch_size=batch_size,
seq_length=seq_length,
is_training=True,
preprocessing_hub_module_url=(
'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'),
label_type='int' if use_tfds else 'float',
)
dataset = loader.SentencePredictionTextDataLoader(data_config).load()
features, labels = next(iter(dataset))
self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(labels.shape, (batch_size,))
if __name__ == '__main__':
tf.test.main()
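# A hedged sketch of what the `preprocessing_hub_module_url` path exercised
# above does, based on the tokenize/bert_pack_inputs interface that the
# TextProcessor in the dataloader relies on:
#
#   import tensorflow_hub as hub
#   preprocess = hub.load(
#       'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')
#   segments = [preprocess.tokenize(tf.constant(['hello world'])),
#               preprocess.tokenize(tf.constant(['how are you']))]
#   packed = preprocess.bert_pack_inputs(segments, seq_length=256)
#   # `packed` holds 'input_word_ids', 'input_mask' and 'input_type_ids'.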
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ==============================================================================
 """BERT library to process data for cross lingual sentence retrieval task."""
 import os
@@ -25,8 +25,7 @@ class BuccProcessor(classifier_data_lib.DataProcessor):
   """Procssor for Xtreme BUCC data set."""
   supported_languages = ["de", "fr", "ru", "zh"]
-  def __init__(self,
-               process_text_fn=tokenization.convert_to_unicode):
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
     super(BuccProcessor, self).__init__(process_text_fn)
     self.languages = BuccProcessor.supported_languages
@@ -50,11 +49,11 @@ class BuccProcessor(classifier_data_lib.DataProcessor):
     examples = []
     for (i, line) in enumerate(lines):
       guid = "%s-%s" % (set_type, i)
-      int_iden = int(line[0].split("-")[1])
+      example_id = int(line[0].split("-")[1])
       text_a = self.process_text_fn(line[1])
       examples.append(
           classifier_data_lib.InputExample(
-              guid=guid, text_a=text_a, int_iden=int_iden))
+              guid=guid, text_a=text_a, example_id=example_id))
     return examples
@@ -66,8 +65,7 @@ class TatoebaProcessor(classifier_data_lib.DataProcessor):
       "nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
   ]
-  def __init__(self,
-               process_text_fn=tokenization.convert_to_unicode):
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
     super(TatoebaProcessor, self).__init__(process_text_fn)
     self.languages = TatoebaProcessor.supported_languages
@@ -88,7 +86,7 @@ class TatoebaProcessor(classifier_data_lib.DataProcessor):
       text_a = self.process_text_fn(line[0])
       examples.append(
           classifier_data_lib.InputExample(
-              guid=guid, text_a=text_a, int_iden=i))
+              guid=guid, text_a=text_a, example_id=i))
     return examples
...
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,19 +11,15 @@ ...@@ -11,19 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Library to process data for SQuAD 1.1 and SQuAD 2.0."""
"""Library to process data for SQuAD 1.1 and SQuAD 2.0."""
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections import collections
import copy import copy
import json import json
import math import math
import os import os
import six import six
from absl import logging from absl import logging
...@@ -40,8 +36,8 @@ class SquadExample(object): ...@@ -40,8 +36,8 @@ class SquadExample(object):
Attributes: Attributes:
qas_id: ID of the question-answer pair. qas_id: ID of the question-answer pair.
question_text: Original text for the question. question_text: Original text for the question.
doc_tokens: The list of tokens in the context obtained by splitting doc_tokens: The list of tokens in the context obtained by splitting on
on whitespace only. whitespace only.
orig_answer_text: Original text for the answer. orig_answer_text: Original text for the answer.
start_position: Starting index of the answer in `doc_tokens`. start_position: Starting index of the answer in `doc_tokens`.
end_position: Ending index of the answer in `doc_tokens`. end_position: Ending index of the answer in `doc_tokens`.
...@@ -96,6 +92,8 @@ class InputFeatures(object): ...@@ -96,6 +92,8 @@ class InputFeatures(object):
input_ids, input_ids,
input_mask, input_mask,
segment_ids, segment_ids,
paragraph_mask=None,
class_index=None,
start_position=None, start_position=None,
end_position=None, end_position=None,
is_impossible=None): is_impossible=None):
...@@ -111,6 +109,8 @@ class InputFeatures(object): ...@@ -111,6 +109,8 @@ class InputFeatures(object):
self.start_position = start_position self.start_position = start_position
self.end_position = end_position self.end_position = end_position
self.is_impossible = is_impossible self.is_impossible = is_impossible
self.paragraph_mask = paragraph_mask
self.class_index = class_index
class FeatureWriter(object): class FeatureWriter(object):
...@@ -138,6 +138,11 @@ class FeatureWriter(object): ...@@ -138,6 +138,11 @@ class FeatureWriter(object):
features["input_mask"] = create_int_feature(feature.input_mask) features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids) features["segment_ids"] = create_int_feature(feature.segment_ids)
if feature.paragraph_mask is not None:
features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
if feature.class_index is not None:
features["class_index"] = create_int_feature([feature.class_index])
if self.is_training: if self.is_training:
features["start_positions"] = create_int_feature([feature.start_position]) features["start_positions"] = create_int_feature([feature.start_position])
features["end_positions"] = create_int_feature([feature.end_position]) features["end_positions"] = create_int_feature([feature.end_position])
...@@ -153,11 +158,20 @@ class FeatureWriter(object): ...@@ -153,11 +158,20 @@ class FeatureWriter(object):
self._writer.close() self._writer.close()
def read_squad_examples(input_file, is_training, version_2_with_negative): def read_squad_examples(input_file, is_training,
version_2_with_negative,
translated_input_folder=None):
"""Read a SQuAD json file into a list of SquadExample.""" """Read a SQuAD json file into a list of SquadExample."""
with tf.io.gfile.GFile(input_file, "r") as reader: with tf.io.gfile.GFile(input_file, "r") as reader:
input_data = json.load(reader)["data"] input_data = json.load(reader)["data"]
if translated_input_folder is not None:
translated_files = tf.io.gfile.glob(
os.path.join(translated_input_folder, "*.json"))
for file in translated_files:
with tf.io.gfile.GFile(file, "r") as reader:
input_data.extend(json.load(reader)["data"])
def is_whitespace(c): def is_whitespace(c):
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
return True return True
...@@ -209,8 +223,8 @@ def read_squad_examples(input_file, is_training, version_2_with_negative): ...@@ -209,8 +223,8 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
# #
# Note that this means for training mode, every example is NOT # Note that this means for training mode, every example is NOT
# guaranteed to be preserved. # guaranteed to be preserved.
actual_text = " ".join( actual_text = " ".join(doc_tokens[start_position:(end_position +
doc_tokens[start_position:(end_position + 1)]) 1)])
cleaned_answer_text = " ".join( cleaned_answer_text = " ".join(
tokenization.whitespace_tokenize(orig_answer_text)) tokenization.whitespace_tokenize(orig_answer_text))
if actual_text.find(cleaned_answer_text) == -1: if actual_text.find(cleaned_answer_text) == -1:
...@@ -242,6 +256,7 @@ def convert_examples_to_features(examples, ...@@ -242,6 +256,7 @@ def convert_examples_to_features(examples,
max_query_length, max_query_length,
is_training, is_training,
output_fn, output_fn,
xlnet_format=False,
batch_size=None): batch_size=None):
"""Loads a data file into a list of `InputBatch`s.""" """Loads a data file into a list of `InputBatch`s."""
...@@ -303,25 +318,54 @@ def convert_examples_to_features(examples, ...@@ -303,25 +318,54 @@ def convert_examples_to_features(examples,
token_to_orig_map = {} token_to_orig_map = {}
token_is_max_context = {} token_is_max_context = {}
segment_ids = [] segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0) # Paragraph mask used in XLNet.
for token in query_tokens: # 1 represents paragraph and class tokens.
tokens.append(token) # 0 represents query and other special tokens.
segment_ids.append(0) paragraph_mask = []
tokens.append("[SEP]")
segment_ids.append(0) # pylint: disable=cell-var-from-loop
def process_query(seg_q):
for i in range(doc_span.length): for token in query_tokens:
split_token_index = doc_span.start + i tokens.append(token)
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] segment_ids.append(seg_q)
paragraph_mask.append(0)
is_max_context = _check_is_max_context(doc_spans, doc_span_index, tokens.append("[SEP]")
split_token_index) segment_ids.append(seg_q)
token_is_max_context[len(tokens)] = is_max_context paragraph_mask.append(0)
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(1) def process_paragraph(seg_p):
tokens.append("[SEP]") for i in range(doc_span.length):
segment_ids.append(1) split_token_index = doc_span.start + i
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
split_token_index)
token_is_max_context[len(tokens)] = is_max_context
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(seg_p)
paragraph_mask.append(1)
tokens.append("[SEP]")
segment_ids.append(seg_p)
paragraph_mask.append(0)
def process_class(seg_class):
class_index = len(segment_ids)
tokens.append("[CLS]")
segment_ids.append(seg_class)
paragraph_mask.append(1)
return class_index
if xlnet_format:
seg_p, seg_q, seg_class, seg_pad = 0, 1, 2, 3
process_paragraph(seg_p)
process_query(seg_q)
class_index = process_class(seg_class)
else:
seg_p, seg_q, seg_class, seg_pad = 1, 0, 0, 0
class_index = process_class(seg_class)
process_query(seg_q)
process_paragraph(seg_p)
input_ids = tokenizer.convert_tokens_to_ids(tokens) input_ids = tokenizer.convert_tokens_to_ids(tokens)
...@@ -333,35 +377,30 @@ def convert_examples_to_features(examples, ...@@ -333,35 +377,30 @@ def convert_examples_to_features(examples,
while len(input_ids) < max_seq_length: while len(input_ids) < max_seq_length:
input_ids.append(0) input_ids.append(0)
input_mask.append(0) input_mask.append(0)
segment_ids.append(0) segment_ids.append(seg_pad)
paragraph_mask.append(0)
assert len(input_ids) == max_seq_length assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length assert len(segment_ids) == max_seq_length
assert len(paragraph_mask) == max_seq_length
start_position = 0
end_position = 0
span_contains_answer = False
start_position = None
end_position = None
if is_training and not example.is_impossible: if is_training and not example.is_impossible:
# For training, if our document chunk does not contain an annotation # For training, if our document chunk does not contain an annotation
# we throw it out, since there is nothing to predict. # we throw it out, since there is nothing to predict.
doc_start = doc_span.start doc_start = doc_span.start
doc_end = doc_span.start + doc_span.length - 1 doc_end = doc_span.start + doc_span.length - 1
out_of_span = False span_contains_answer = (tok_start_position >= doc_start and
if not (tok_start_position >= doc_start and tok_end_position <= doc_end)
tok_end_position <= doc_end): if span_contains_answer:
out_of_span = True doc_offset = 0 if xlnet_format else len(query_tokens) + 2
if out_of_span:
start_position = 0
end_position = 0
else:
doc_offset = len(query_tokens) + 2
start_position = tok_start_position - doc_start + doc_offset start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset end_position = tok_end_position - doc_start + doc_offset
if is_training and example.is_impossible:
start_position = 0
end_position = 0
if example_index < 20: if example_index < 20:
logging.info("*** Example ***") logging.info("*** Example ***")
logging.info("unique_id: %s", (unique_id)) logging.info("unique_id: %s", (unique_id))
...@@ -381,19 +420,25 @@ def convert_examples_to_features(examples, ...@@ -381,19 +420,25 @@ def convert_examples_to_features(examples,
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids])) logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask])) logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
if is_training and example.is_impossible: logging.info("paragraph_mask: %s", " ".join(
logging.info("impossible example") [str(x) for x in paragraph_mask]))
if is_training and not example.is_impossible: logging.info("class_index: %d", class_index)
answer_text = " ".join(tokens[start_position:(end_position + 1)]) if is_training:
logging.info("start_position: %d", (start_position)) if span_contains_answer:
logging.info("end_position: %d", (end_position)) answer_text = " ".join(tokens[start_position:(end_position + 1)])
logging.info("answer: %s", tokenization.printable_text(answer_text)) logging.info("start_position: %d", (start_position))
logging.info("end_position: %d", (end_position))
logging.info("answer: %s", tokenization.printable_text(answer_text))
else:
logging.info("document span doesn't contain answer")
feature = InputFeatures( feature = InputFeatures(
unique_id=unique_id, unique_id=unique_id,
example_index=example_index, example_index=example_index,
doc_span_index=doc_span_index, doc_span_index=doc_span_index,
tokens=tokens, tokens=tokens,
paragraph_mask=paragraph_mask,
class_index=class_index,
token_to_orig_map=token_to_orig_map, token_to_orig_map=token_to_orig_map,
token_is_max_context=token_is_max_context, token_is_max_context=token_is_max_context,
input_ids=input_ids, input_ids=input_ids,
...@@ -401,7 +446,7 @@ def convert_examples_to_features(examples, ...@@ -401,7 +446,7 @@ def convert_examples_to_features(examples,
segment_ids=segment_ids, segment_ids=segment_ids,
start_position=start_position, start_position=start_position,
end_position=end_position, end_position=end_position,
is_impossible=example.is_impossible) is_impossible=not span_contains_answer)
# Run callback # Run callback
if is_training: if is_training:
...@@ -520,15 +565,16 @@ def write_predictions(all_examples, ...@@ -520,15 +565,16 @@ def write_predictions(all_examples,
logging.info("Writing nbest to: %s", (output_nbest_file)) logging.info("Writing nbest to: %s", (output_nbest_file))
all_predictions, all_nbest_json, scores_diff_json = ( all_predictions, all_nbest_json, scores_diff_json = (
postprocess_output(all_examples=all_examples, postprocess_output(
all_features=all_features, all_examples=all_examples,
all_results=all_results, all_features=all_features,
n_best_size=n_best_size, all_results=all_results,
max_answer_length=max_answer_length, n_best_size=n_best_size,
do_lower_case=do_lower_case, max_answer_length=max_answer_length,
version_2_with_negative=version_2_with_negative, do_lower_case=do_lower_case,
null_score_diff_threshold=null_score_diff_threshold, version_2_with_negative=version_2_with_negative,
verbose=verbose)) null_score_diff_threshold=null_score_diff_threshold,
verbose=verbose))
write_to_json_files(all_predictions, output_prediction_file) write_to_json_files(all_predictions, output_prediction_file)
write_to_json_files(all_nbest_json, output_nbest_file) write_to_json_files(all_nbest_json, output_nbest_file)
...@@ -544,6 +590,7 @@ def postprocess_output(all_examples, ...@@ -544,6 +590,7 @@ def postprocess_output(all_examples,
do_lower_case, do_lower_case,
version_2_with_negative=False, version_2_with_negative=False,
null_score_diff_threshold=0.0, null_score_diff_threshold=0.0,
xlnet_format=False,
verbose=False): verbose=False):
"""Postprocess model output, to form predicton results.""" """Postprocess model output, to form predicton results."""
...@@ -572,46 +619,54 @@ def postprocess_output(all_examples, ...@@ -572,46 +619,54 @@ def postprocess_output(all_examples,
null_start_logit = 0 # the start logit at the slice with min null score null_start_logit = 0 # the start logit at the slice with min null score
null_end_logit = 0 # the end logit at the slice with min null score null_end_logit = 0 # the end logit at the slice with min null score
for (feature_index, feature) in enumerate(features): for (feature_index, feature) in enumerate(features):
if feature.unique_id not in unique_id_to_result:
logging.info("Skip eval example %s, not in pred.", feature.unique_id)
continue
result = unique_id_to_result[feature.unique_id] result = unique_id_to_result[feature.unique_id]
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
# if we could have irrelevant answers, get the min score of irrelevant # if we could have irrelevant answers, get the min score of irrelevant
if version_2_with_negative: if version_2_with_negative:
feature_null_score = result.start_logits[0] + result.end_logits[0] if xlnet_format:
feature_null_score = result.class_logits
else:
feature_null_score = result.start_logits[0] + result.end_logits[0]
if feature_null_score < score_null: if feature_null_score < score_null:
score_null = feature_null_score score_null = feature_null_score
min_null_feature_index = feature_index min_null_feature_index = feature_index
null_start_logit = result.start_logits[0] null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0] null_end_logit = result.end_logits[0]
for start_index in start_indexes: for (start_index, start_logit,
for end_index in end_indexes: end_index, end_logit) in _get_best_indexes_and_logits(
# We could hypothetically create invalid predictions, e.g., predict result=result,
# that the start of the span is in the question. We throw out all n_best_size=n_best_size,
# invalid predictions. xlnet_format=xlnet_format):
if start_index >= len(feature.tokens): # We could hypothetically create invalid predictions, e.g., predict
continue # that the start of the span is in the question. We throw out all
if end_index >= len(feature.tokens): # invalid predictions.
continue if start_index >= len(feature.tokens):
if start_index not in feature.token_to_orig_map: continue
continue if end_index >= len(feature.tokens):
if end_index not in feature.token_to_orig_map: continue
continue if start_index not in feature.token_to_orig_map:
if not feature.token_is_max_context.get(start_index, False): continue
continue if end_index not in feature.token_to_orig_map:
if end_index < start_index: continue
continue if not feature.token_is_max_context.get(start_index, False):
length = end_index - start_index + 1 continue
if length > max_answer_length: if end_index < start_index:
continue continue
prelim_predictions.append( length = end_index - start_index + 1
_PrelimPrediction( if length > max_answer_length:
feature_index=feature_index, continue
start_index=start_index, prelim_predictions.append(
end_index=end_index, _PrelimPrediction(
start_logit=result.start_logits[start_index], feature_index=feature_index,
end_logit=result.end_logits[end_index])) start_index=start_index,
end_index=end_index,
if version_2_with_negative: start_logit=start_logit,
end_logit=end_logit))
if version_2_with_negative and not xlnet_format:
prelim_predictions.append( prelim_predictions.append(
_PrelimPrediction( _PrelimPrediction(
feature_index=min_null_feature_index, feature_index=min_null_feature_index,
...@@ -633,7 +688,7 @@ def postprocess_output(all_examples, ...@@ -633,7 +688,7 @@ def postprocess_output(all_examples,
if len(nbest) >= n_best_size: if len(nbest) >= n_best_size:
break break
feature = features[pred.feature_index] feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction if pred.start_index > 0 or xlnet_format: # this is a non-null prediction
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index] orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index] orig_doc_end = feature.token_to_orig_map[pred.end_index]
...@@ -666,7 +721,7 @@ def postprocess_output(all_examples, ...@@ -666,7 +721,7 @@ def postprocess_output(all_examples,
end_logit=pred.end_logit)) end_logit=pred.end_logit))
# if we didn't inlude the empty option in the n-best, inlcude it # if we didn't inlude the empty option in the n-best, inlcude it
if version_2_with_negative: if version_2_with_negative and not xlnet_format:
if "" not in seen_predictions: if "" not in seen_predictions:
nbest.append( nbest.append(
_NbestPrediction( _NbestPrediction(
...@@ -707,13 +762,18 @@ def postprocess_output(all_examples, ...@@ -707,13 +762,18 @@ def postprocess_output(all_examples,
# pytype: disable=attribute-error # pytype: disable=attribute-error
# predict "" iff the null score - the score of best non-null > threshold # predict "" iff the null score - the score of best non-null > threshold
if best_non_null_entry is not None: if best_non_null_entry is not None:
score_diff = score_null - best_non_null_entry.start_logit - ( if xlnet_format:
best_non_null_entry.end_logit) score_diff = score_null
scores_diff_json[example.qas_id] = score_diff scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text all_predictions[example.qas_id] = best_non_null_entry.text
else:
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
else: else:
logging.warning("best_non_null_entry is None") logging.warning("best_non_null_entry is None")
scores_diff_json[example.qas_id] = score_null scores_diff_json[example.qas_id] = score_null
...@@ -825,16 +885,29 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose=False): ...@@ -825,16 +885,29 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
  return output_text


def _get_best_indexes_and_logits(result,
                                 n_best_size,
                                 xlnet_format=False):
  """Generates the n-best indexes and logits from a list."""
  if xlnet_format:
    for i in range(n_best_size):
      for j in range(n_best_size):
        j_index = i * n_best_size + j
        yield (result.start_indexes[i], result.start_logits[i],
               result.end_indexes[j_index], result.end_logits[j_index])
  else:
    start_index_and_score = sorted(enumerate(result.start_logits),
                                   key=lambda x: x[1], reverse=True)
    end_index_and_score = sorted(enumerate(result.end_logits),
                                 key=lambda x: x[1], reverse=True)
    for i in range(len(start_index_and_score)):
      if i >= n_best_size:
        break
      for j in range(len(end_index_and_score)):
        if j >= n_best_size:
          break
        yield (start_index_and_score[i][0], start_index_and_score[i][1],
               end_index_and_score[j][0], end_index_and_score[j][1])
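

# A minimal usage sketch (not part of the library): the generator yields
# (start_index, start_logit, end_index, end_logit) candidate pairs.
# `FakeResult` below is hypothetical and only mimics the attributes the
# non-XLNet branch reads.
#
#   import collections
#   FakeResult = collections.namedtuple(
#       "FakeResult", ["start_logits", "end_logits"])
#   fake = FakeResult(start_logits=[0.1, 2.0, 0.5],
#                     end_logits=[1.5, 0.2, 0.9])
#   for start, s_logit, end, e_logit in _get_best_indexes_and_logits(
#       result=fake, n_best_size=2):
#     print(start, s_logit, end, e_logit)  # prints 2 x 2 = 4 candidate pairs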
def _compute_softmax(scores):
@@ -863,16 +936,19 @@ def _compute_softmax(scores):
def generate_tf_record_from_json_file(input_file_path,
                                      vocab_file_path,
                                      output_path,
                                      translated_input_folder=None,
                                      max_seq_length=384,
                                      do_lower_case=True,
                                      max_query_length=64,
                                      doc_stride=128,
                                      version_2_with_negative=False,
                                      xlnet_format=False):
  """Generates and saves training data into a tf record file."""
  train_examples = read_squad_examples(
      input_file=input_file_path,
      is_training=True,
      version_2_with_negative=version_2_with_negative,
      translated_input_folder=translated_input_folder)
  tokenizer = tokenization.FullTokenizer(
      vocab_file=vocab_file_path, do_lower_case=do_lower_case)
  train_writer = FeatureWriter(filename=output_path, is_training=True)
@@ -883,7 +959,8 @@ def generate_tf_record_from_json_file(input_file_path,
      doc_stride=doc_stride,
      max_query_length=max_query_length,
      is_training=True,
      output_fn=train_writer.process_feature,
      xlnet_format=xlnet_format)
  train_writer.close()

  meta_data = {
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,22 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 using sentence piece tokenization.

The file is forked from:
https://github.com/google-research/ALBERT/blob/master/run_squad_sp.py
"""
import collections
import copy
import json
import math
import os

from absl import logging
import numpy as np
import tensorflow as tf
@@ -89,6 +86,8 @@ class InputFeatures(object):
               input_mask,
               segment_ids,
               paragraph_len,
               class_index=None,
               paragraph_mask=None,
               start_position=None,
               end_position=None,
               is_impossible=None):
@@ -101,19 +100,31 @@ class InputFeatures(object):
    self.tokens = tokens
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.paragraph_mask = paragraph_mask
    self.segment_ids = segment_ids
    self.paragraph_len = paragraph_len
    self.class_index = class_index
    self.start_position = start_position
    self.end_position = end_position
    self.is_impossible = is_impossible


def read_squad_examples(input_file,
                        is_training,
                        version_2_with_negative,
                        translated_input_folder=None):
  """Read a SQuAD json file into a list of SquadExample."""
  del version_2_with_negative
  with tf.io.gfile.GFile(input_file, "r") as reader:
    input_data = json.load(reader)["data"]

  if translated_input_folder is not None:
    translated_files = tf.io.gfile.glob(
        os.path.join(translated_input_folder, "*.json"))
    for file in translated_files:
      with tf.io.gfile.GFile(file, "r") as reader:
        input_data.extend(json.load(reader)["data"])
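
  # Note: translated files are assumed to follow the same SQuAD json schema
  # as `input_file`; their "data" entries are appended as-is, so the examples
  # below are built from the union of the original and translated datasets.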

  examples = []
  for entry in input_data:
    for paragraph in entry["paragraphs"]:
@@ -197,6 +208,7 @@ def convert_examples_to_features(examples,
                                 is_training,
                                 output_fn,
                                 do_lower_case,
                                 xlnet_format=False,
                                 batch_size=None):
  """Loads a data file into a list of `InputBatch`s."""
  cnt_pos, cnt_neg = 0, 0
@@ -246,6 +258,7 @@ def convert_examples_to_features(examples,
    f = np.zeros((max_n, max_m), dtype=np.float32)
    g = {}

    # pylint: disable=cell-var-from-loop
    def _lcs_match(max_dist, n=n, m=m):
      """Longest-common-substring algorithm."""
@@ -277,6 +290,7 @@ def convert_examples_to_features(examples,
                  remove_space=False) == tok_cat_text[j] and f_prev + 1 > f[i, j]):
            g[(i, j)] = 2
            f[i, j] = f_prev + 1

    # pylint: enable=cell-var-from-loop
    max_dist = abs(n - m) + 5
@@ -354,6 +368,7 @@ def convert_examples_to_features(examples,
"DocSpan", ["start", "length"]) "DocSpan", ["start", "length"])
doc_spans = [] doc_spans = []
start_offset = 0 start_offset = 0
while start_offset < len(all_doc_tokens): while start_offset < len(all_doc_tokens):
length = len(all_doc_tokens) - start_offset length = len(all_doc_tokens) - start_offset
if length > max_tokens_for_doc: if length > max_tokens_for_doc:
...@@ -368,34 +383,62 @@ def convert_examples_to_features(examples, ...@@ -368,34 +383,62 @@ def convert_examples_to_features(examples,
      token_is_max_context = {}
      segment_ids = []

      # Paragraph mask used in XLNet.
      # 1 represents paragraph and class tokens.
      # 0 represents query and other special tokens.
      paragraph_mask = []

      cur_tok_start_to_orig_index = []
      cur_tok_end_to_orig_index = []

      # pylint: disable=cell-var-from-loop
      def process_query(seg_q):
        for token in query_tokens:
          tokens.append(token)
          segment_ids.append(seg_q)
          paragraph_mask.append(0)
        tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
        segment_ids.append(seg_q)
        paragraph_mask.append(0)

      def process_paragraph(seg_p):
        for i in range(doc_span.length):
          split_token_index = doc_span.start + i

          cur_tok_start_to_orig_index.append(
              tok_start_to_orig_index[split_token_index])
          cur_tok_end_to_orig_index.append(
              tok_end_to_orig_index[split_token_index])

          is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                                 split_token_index)
          token_is_max_context[len(tokens)] = is_max_context
          tokens.append(all_doc_tokens[split_token_index])
          segment_ids.append(seg_p)
          paragraph_mask.append(1)
        tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
        segment_ids.append(seg_p)
        paragraph_mask.append(0)
        return len(tokens)

      def process_class(seg_class):
        class_index = len(segment_ids)
        tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
        segment_ids.append(seg_class)
        paragraph_mask.append(1)
        return class_index

      if xlnet_format:
        seg_p, seg_q, seg_class, seg_pad = 0, 1, 2, 3
        paragraph_len = process_paragraph(seg_p)
        process_query(seg_q)
        class_index = process_class(seg_class)
      else:
        seg_p, seg_q, seg_class, seg_pad = 1, 0, 0, 0
        class_index = process_class(seg_class)
        process_query(seg_q)
        paragraph_len = process_paragraph(seg_p)
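
      # Resulting layouts (a sketch of the two orderings above): XLNet places
      # the paragraph first and the class token last, while the BERT/ALBERT
      # layout keeps [CLS] first:
      #   xlnet_format:  [paragraph](0) [SEP](0) [query](1) [SEP](1) [CLS](2)
      #   otherwise:     [CLS](0) [query](0) [SEP](0) [paragraph](1) [SEP](1)
      # Numbers in parentheses are segment ids; padding positions use seg_pad.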

      input_ids = tokens
      # The mask has 1 for real tokens and 0 for padding tokens. Only real
@@ -406,11 +449,13 @@ def convert_examples_to_features(examples,
      while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(seg_pad)
        paragraph_mask.append(0)

      assert len(input_ids) == max_seq_length
      assert len(input_mask) == max_seq_length
      assert len(segment_ids) == max_seq_length
      assert len(paragraph_mask) == max_seq_length

      span_is_impossible = example.is_impossible
      start_position = None
@@ -430,13 +475,13 @@ def convert_examples_to_features(examples,
          end_position = 0
          span_is_impossible = True
        else:
          doc_offset = 0 if xlnet_format else len(query_tokens) + 2
          start_position = tok_start_position - doc_start + doc_offset
          end_position = tok_end_position - doc_start + doc_offset

      if is_training and span_is_impossible:
        start_position = class_index
        end_position = class_index
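      # Unanswerable spans point the start/end targets at the [CLS]/class
      # token via `class_index`; in the non-XLNet layout that index is 0, so
      # this generalizes the previous hard-coded zero positions.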

      if example_index < 20:
        logging.info("*** Example ***")
@@ -456,6 +501,9 @@ def convert_examples_to_features(examples,
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids])) logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask])) logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
logging.info("paragraph_mask: %s", " ".join(
[str(x) for x in paragraph_mask]))
logging.info("class_index: %d", class_index)
if is_training and span_is_impossible: if is_training and span_is_impossible:
logging.info("impossible example span") logging.info("impossible example span")
...@@ -489,8 +537,10 @@ def convert_examples_to_features(examples, ...@@ -489,8 +537,10 @@ def convert_examples_to_features(examples,
          tokens=[tokenizer.sp_model.IdToPiece(x) for x in tokens],
          input_ids=input_ids,
          input_mask=input_mask,
          paragraph_mask=paragraph_mask,
          segment_ids=segment_ids,
          paragraph_len=paragraph_len,
          class_index=class_index,
          start_position=start_position,
          end_position=end_position,
          is_impossible=span_is_impossible)
@@ -580,15 +630,16 @@ def write_predictions(all_examples,
logging.info("Writing nbest to: %s", (output_nbest_file)) logging.info("Writing nbest to: %s", (output_nbest_file))
all_predictions, all_nbest_json, scores_diff_json = ( all_predictions, all_nbest_json, scores_diff_json = (
postprocess_output(all_examples=all_examples, postprocess_output(
all_features=all_features, all_examples=all_examples,
all_results=all_results, all_features=all_features,
n_best_size=n_best_size, all_results=all_results,
max_answer_length=max_answer_length, n_best_size=n_best_size,
do_lower_case=do_lower_case, max_answer_length=max_answer_length,
version_2_with_negative=version_2_with_negative, do_lower_case=do_lower_case,
null_score_diff_threshold=null_score_diff_threshold, version_2_with_negative=version_2_with_negative,
verbose=verbose)) null_score_diff_threshold=null_score_diff_threshold,
verbose=verbose))
write_to_json_files(all_predictions, output_prediction_file) write_to_json_files(all_predictions, output_prediction_file)
write_to_json_files(all_nbest_json, output_nbest_file) write_to_json_files(all_nbest_json, output_nbest_file)
...@@ -604,11 +655,11 @@ def postprocess_output(all_examples, ...@@ -604,11 +655,11 @@ def postprocess_output(all_examples,
                       do_lower_case,
                       version_2_with_negative=False,
                       null_score_diff_threshold=0.0,
                       xlnet_format=False,
                       verbose=False):
  """Postprocess model output, to form prediction results."""
  del do_lower_case, verbose
  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)
@@ -635,47 +686,53 @@ def postprocess_output(all_examples,
    null_start_logit = 0  # the start logit at the slice with min null score
    null_end_logit = 0  # the end logit at the slice with min null score
    for (feature_index, feature) in enumerate(features):
      if feature.unique_id not in unique_id_to_result:
        logging.info("Skip eval example %s, not in pred.", feature.unique_id)
        continue
      result = unique_id_to_result[feature.unique_id]

      # if we could have irrelevant answers, get the min score of irrelevant
      if version_2_with_negative:
        if xlnet_format:
          feature_null_score = result.class_logits
        else:
          feature_null_score = result.start_logits[0] + result.end_logits[0]
        if feature_null_score < score_null:
          score_null = feature_null_score
          min_null_feature_index = feature_index
          null_start_logit = result.start_logits[0]
          null_end_logit = result.end_logits[0]

      doc_offset = 0 if xlnet_format else feature.tokens.index("[SEP]") + 1

      for (start_index, start_logit,
           end_index, end_logit) in _get_best_indexes_and_logits(
               result=result,
               n_best_size=n_best_size,
               xlnet_format=xlnet_format):
        # We could hypothetically create invalid predictions, e.g., predict
        # that the start of the span is in the question. We throw out all
        # invalid predictions.
        if start_index - doc_offset >= len(feature.tok_start_to_orig_index):
          continue
        if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
          continue
        if not feature.token_is_max_context.get(start_index, False):
          continue
        if end_index < start_index:
          continue
        length = end_index - start_index + 1
        if length > max_answer_length:
          continue
        prelim_predictions.append(
            _PrelimPrediction(
                feature_index=feature_index,
                start_index=start_index - doc_offset,
                end_index=end_index - doc_offset,
                start_logit=start_logit,
                end_logit=end_logit))

    if version_2_with_negative and not xlnet_format:
      prelim_predictions.append(
          _PrelimPrediction(
              feature_index=min_null_feature_index,
@@ -697,7 +754,7 @@ def postprocess_output(all_examples,
      if len(nbest) >= n_best_size:
        break
      feature = features[pred.feature_index]
      if pred.start_index >= 0 or xlnet_format:  # this is a non-null prediction
        tok_start_to_orig_index = feature.tok_start_to_orig_index
        tok_end_to_orig_index = feature.tok_end_to_orig_index
        start_orig_pos = tok_start_to_orig_index[pred.start_index]
@@ -719,8 +776,8 @@ def postprocess_output(all_examples,
              start_logit=pred.start_logit,
              end_logit=pred.end_logit))

    # if we didn't include the empty option in the n-best, include it
    if version_2_with_negative and not xlnet_format:
      if "" not in seen_predictions:
        nbest.append(
            _NbestPrediction(
@@ -759,14 +816,19 @@ def postprocess_output(all_examples,
      all_predictions[example.qas_id] = nbest_json[0]["text"]
    else:
      assert best_non_null_entry is not None
      if xlnet_format:
        score_diff = score_null
        scores_diff_json[example.qas_id] = score_diff
        all_predictions[example.qas_id] = best_non_null_entry.text
      else:
        # predict "" iff the null score - the score of best non-null > threshold
        score_diff = score_null - best_non_null_entry.start_logit - (
            best_non_null_entry.end_logit)
        scores_diff_json[example.qas_id] = score_diff
        if score_diff > null_score_diff_threshold:
          all_predictions[example.qas_id] = ""
        else:
          all_predictions[example.qas_id] = best_non_null_entry.text

    all_nbest_json[example.qas_id] = nbest_json
@@ -778,16 +840,29 @@ def write_to_json_files(json_records, json_file):
  writer.write(json.dumps(json_records, indent=4) + "\n")


def _get_best_indexes_and_logits(result,
                                 n_best_size,
                                 xlnet_format=False):
  """Generates the n-best indexes and logits from a list."""
  if xlnet_format:
    for i in range(n_best_size):
      for j in range(n_best_size):
        j_index = i * n_best_size + j
        yield (result.start_indexes[i], result.start_logits[i],
               result.end_indexes[j_index], result.end_logits[j_index])
  else:
    start_index_and_score = sorted(enumerate(result.start_logits),
                                   key=lambda x: x[1], reverse=True)
    end_index_and_score = sorted(enumerate(result.end_logits),
                                 key=lambda x: x[1], reverse=True)
    for i in range(len(start_index_and_score)):
      if i >= n_best_size:
        break
      for j in range(len(end_index_and_score)):
        if j >= n_best_size:
          break
        yield (start_index_and_score[i][0], start_index_and_score[i][1],
               end_index_and_score[j][0], end_index_and_score[j][1])


def _compute_softmax(scores):
@@ -837,6 +912,10 @@ class FeatureWriter(object):
features["input_ids"] = create_int_feature(feature.input_ids) features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask) features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids) features["segment_ids"] = create_int_feature(feature.segment_ids)
if feature.paragraph_mask is not None:
features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
if feature.class_index is not None:
features["class_index"] = create_int_feature([feature.class_index])
if self.is_training: if self.is_training:
features["start_positions"] = create_int_feature([feature.start_position]) features["start_positions"] = create_int_feature([feature.start_position])
...@@ -856,19 +935,23 @@ class FeatureWriter(object): ...@@ -856,19 +935,23 @@ class FeatureWriter(object):
def generate_tf_record_from_json_file(input_file_path,
                                      sp_model_file,
                                      output_path,
                                      translated_input_folder=None,
                                      max_seq_length=384,
                                      do_lower_case=True,
                                      max_query_length=64,
                                      doc_stride=128,
                                      xlnet_format=False,
                                      version_2_with_negative=False):
  """Generates and saves training data into a tf record file."""
  train_examples = read_squad_examples(
      input_file=input_file_path,
      is_training=True,
      version_2_with_negative=version_2_with_negative,
      translated_input_folder=translated_input_folder)
  tokenizer = tokenization.FullSentencePieceTokenizer(
      sp_model_file=sp_model_file)
  train_writer = FeatureWriter(filename=output_path, is_training=True)
  number_of_examples = convert_examples_to_features(
      examples=train_examples,
      tokenizer=tokenizer,
@@ -877,6 +960,7 @@ def generate_tf_record_from_json_file(input_file_path,
      max_query_length=max_query_length,
      is_training=True,
      output_fn=train_writer.process_feature,
      xlnet_format=xlnet_format,
      do_lower_case=do_lower_case)
  train_writer.close()
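
# A hypothetical invocation sketch (paths are placeholders, not from the
# repo): producing XLNet-style SQuAD v2 training records with sentence piece
# tokenization.
#
#   generate_tf_record_from_json_file(
#       input_file_path="/tmp/squad/train-v2.0.json",
#       sp_model_file="/tmp/albert/30k-clean.model",
#       output_path="/tmp/squad/train.tf_record",
#       version_2_with_negative=True,
#       xlnet_format=True)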


# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to process data for tagging tasks such as NER/POS."""
import collections
import os
@@ -19,6 +19,7 @@ import os
from absl import logging
import tensorflow as tf

from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib

# A negative label id for the padding label, which will not contribute
@@ -33,9 +34,14 @@ _UNK_TOKEN = "[UNK]"
class InputExample(object):
  """A single training/test example for token classification."""
  def __init__(self,
               sentence_id,
               sub_sentence_id=0,
               words=None,
               label_ids=None):
    """Constructs an InputExample."""
    self.sentence_id = sentence_id
    self.sub_sentence_id = sub_sentence_id
    self.words = words if words else []
    self.label_ids = label_ids if label_ids else []
@@ -84,13 +90,48 @@ class PanxProcessor(classifier_data_lib.DataProcessor):
"tr", "et", "fi", "hu" "tr", "et", "fi", "hu"
] ]
def __init__(self,
process_text_fn=tokenization.convert_to_unicode,
only_use_en_train=True,
only_use_en_dev=True):
"""See base class.
Args:
process_text_fn: See base class.
only_use_en_train: If True, only use english training data. Otherwise, use
training data from all languages.
only_use_en_dev: If True, only use english dev data. Otherwise, use dev
data from all languages.
"""
super(PanxProcessor, self).__init__(process_text_fn)
self.only_use_en_train = only_use_en_train
self.only_use_en_dev = only_use_en_dev
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
return _read_one_file( examples = _read_one_file(
os.path.join(data_dir, "train-en.tsv"), self.get_labels()) os.path.join(data_dir, "train-en.tsv"), self.get_labels())
if not self.only_use_en_train:
for language in self.supported_languages:
if language == "en":
continue
examples.extend(
_read_one_file(
os.path.join(data_dir, f"train-{language}.tsv"),
self.get_labels()))
return examples
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
return _read_one_file( examples = _read_one_file(
os.path.join(data_dir, "dev-en.tsv"), self.get_labels()) os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
if not self.only_use_en_dev:
for language in self.supported_languages:
if language == "en":
continue
examples.extend(
_read_one_file(
os.path.join(data_dir, f"dev-{language}.tsv"),
self.get_labels()))
return examples
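
  # A usage sketch (hypothetical data_dir path): pulling training data from
  # every supported language rather than English only.
  #
  #   processor = PanxProcessor(only_use_en_train=False)
  #   examples = processor.get_train_examples("/path/to/panx")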

  def get_test_examples(self, data_dir):
    examples_dict = {}
@@ -115,13 +156,49 @@ class UdposProcessor(classifier_data_lib.DataProcessor):
"ta", "te", "th", "tl", "tr", "ur", "vi", "yo", "zh" "ta", "te", "th", "tl", "tr", "ur", "vi", "yo", "zh"
] ]
def __init__(self,
process_text_fn=tokenization.convert_to_unicode,
only_use_en_train=True,
only_use_en_dev=True):
"""See base class.
Args:
process_text_fn: See base class.
only_use_en_train: If True, only use english training data. Otherwise, use
training data from all languages.
only_use_en_dev: If True, only use english dev data. Otherwise, use dev
data from all languages.
"""
super(UdposProcessor, self).__init__(process_text_fn)
self.only_use_en_train = only_use_en_train
self.only_use_en_dev = only_use_en_dev
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
return _read_one_file( if self.only_use_en_train:
os.path.join(data_dir, "train-en.tsv"), self.get_labels()) examples = _read_one_file(
os.path.join(data_dir, "train-en.tsv"), self.get_labels())
else:
examples = []
# Uses glob because some languages are missing in train.
for filepath in tf.io.gfile.glob(os.path.join(data_dir, "train-*.tsv")):
examples.extend(
_read_one_file(
filepath,
self.get_labels()))
return examples
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
return _read_one_file( if self.only_use_en_dev:
os.path.join(data_dir, "dev-en.tsv"), self.get_labels()) examples = _read_one_file(
os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
else:
examples = []
for filepath in tf.io.gfile.glob(os.path.join(data_dir, "dev-*.tsv")):
examples.extend(
_read_one_file(
filepath,
self.get_labels()))
return examples

  def get_test_examples(self, data_dir):
    examples_dict = {}
@@ -146,11 +223,11 @@ def _tokenize_example(example, max_length, tokenizer, text_preprocessing=None):
  # Needs additional [CLS] and [SEP] tokens.
  max_length = max_length - 2
  new_examples = []
  new_example = InputExample(sentence_id=example.sentence_id, sub_sentence_id=0)
  if any([x < 0 for x in example.label_ids]):
    raise ValueError("Unexpected negative label_id: %s" % example.label_ids)

  for i, word in enumerate(example.words):
    if text_preprocessing:
      word = text_preprocessing(word)
    subwords = tokenizer.tokenize(word)
@@ -160,7 +237,10 @@ def _tokenize_example(example, max_length, tokenizer, text_preprocessing=None):
    if len(subwords) + len(new_example.words) > max_length:
      # Start a new example.
      new_examples.append(new_example)
      last_sub_sentence_id = new_example.sub_sentence_id
      new_example = InputExample(
          sentence_id=example.sentence_id,
          sub_sentence_id=last_sub_sentence_id + 1)
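      # Each overflow starts a new sub-sentence that keeps the original
      # `sentence_id` but increments `sub_sentence_id`, so predictions can be
      # stitched back into full sentences downstream.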

    for j, subword in enumerate(subwords):
      # Use the real label for the first subword, and pad label for
@@ -203,6 +283,7 @@ def _convert_single_example(example, max_seq_length, tokenizer):
features["segment_ids"] = create_int_feature(segment_ids) features["segment_ids"] = create_int_feature(segment_ids)
features["label_ids"] = create_int_feature(label_ids) features["label_ids"] = create_int_feature(label_ids)
features["sentence_id"] = create_int_feature([example.sentence_id]) features["sentence_id"] = create_int_feature([example.sentence_id])
features["sub_sentence_id"] = create_int_feature([example.sub_sentence_id])
tf_example = tf.train.Example(features=tf.train.Features(feature=features)) tf_example = tf.train.Example(features=tf.train.Features(feature=features))
return tf_example return tf_example
...@@ -267,12 +348,12 @@ def write_example_to_file(examples, ...@@ -267,12 +348,12 @@ def write_example_to_file(examples,
logging.info("Writing example %d of %d to %s", ex_index, len(examples), logging.info("Writing example %d of %d to %s", ex_index, len(examples),
output_file) output_file)
tokenized_examples = _tokenize_example(example, max_seq_length, tokenized_examples = _tokenize_example(example, max_seq_length, tokenizer,
tokenizer, text_preprocessing) text_preprocessing)
num_tokenized_examples += len(tokenized_examples) num_tokenized_examples += len(tokenized_examples)
for per_tokenized_example in tokenized_examples: for per_tokenized_example in tokenized_examples:
tf_example = _convert_single_example( tf_example = _convert_single_example(per_tokenized_example,
per_tokenized_example, max_seq_length, tokenizer) max_seq_length, tokenizer)
writer.write(tf_example.SerializeToString()) writer.write(tf_example.SerializeToString())
writer.close() writer.close()
...@@ -307,17 +388,16 @@ def token_classification_meta_data(train_data_size, ...@@ -307,17 +388,16 @@ def token_classification_meta_data(train_data_size,
  return meta_data


def generate_tf_record_from_data_file(processor, data_dir, tokenizer,
                                      max_seq_length, train_data_output_path,
                                      eval_data_output_path,
                                      test_data_output_path,
                                      text_preprocessing):
  """Generates tfrecord files from the raw data."""
  common_kwargs = dict(
      tokenizer=tokenizer,
      max_seq_length=max_seq_length,
      text_preprocessing=text_preprocessing)
  train_examples = processor.get_train_examples(data_dir)
  train_data_size = write_example_to_file(
      train_examples, output_file=train_data_output_path, **common_kwargs)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.data.tagging_data_lib."""
import os
import random
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import tagging_data_lib
def _create_fake_file(filename, labels, is_test):
def write_one_sentence(writer, length):
for _ in range(length):
line = "hiworld"
if not is_test:
line += "\t%s" % (labels[random.randint(0, len(labels) - 1)])
writer.write(line + "\n")
# Writes two sentences with length of 3 and 12 respectively.
with tf.io.gfile.GFile(filename, "w") as writer:
write_one_sentence(writer, 3)
writer.write("\n")
write_one_sentence(writer, 12)
class TaggingDataLibTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(TaggingDataLibTest, self).setUp()
self.processors = {
"panx": tagging_data_lib.PanxProcessor,
"udpos": tagging_data_lib.UdposProcessor,
}
self.vocab_file = os.path.join(self.get_temp_dir(), "vocab.txt")
with tf.io.gfile.GFile(self.vocab_file, "w") as writer:
writer.write("\n".join(["[CLS]", "[SEP]", "hi", "##world", "[UNK]"]))
@parameterized.parameters(
{"task_type": "panx"},
{"task_type": "udpos"},
)
def test_generate_tf_record(self, task_type):
processor = self.processors[task_type]()
input_data_dir = os.path.join(self.get_temp_dir(), task_type)
tf.io.gfile.mkdir(input_data_dir)
# Write fake train file.
_create_fake_file(
os.path.join(input_data_dir, "train-en.tsv"),
processor.get_labels(),
is_test=False)
# Write fake dev file.
_create_fake_file(
os.path.join(input_data_dir, "dev-en.tsv"),
processor.get_labels(),
is_test=False)
# Write fake test files.
for lang in processor.supported_languages:
_create_fake_file(
os.path.join(input_data_dir, "test-%s.tsv" % lang),
processor.get_labels(),
is_test=True)
output_path = os.path.join(self.get_temp_dir(), task_type, "output")
tokenizer = tokenization.FullTokenizer(
vocab_file=self.vocab_file, do_lower_case=True)
metadata = tagging_data_lib.generate_tf_record_from_data_file(
processor,
input_data_dir,
tokenizer,
max_seq_length=8,
train_data_output_path=os.path.join(output_path, "train.tfrecord"),
eval_data_output_path=os.path.join(output_path, "eval.tfrecord"),
test_data_output_path=os.path.join(output_path, "test_{}.tfrecord"),
text_preprocessing=tokenization.convert_to_unicode)
self.assertEqual(metadata["train_data_size"], 5)
files = tf.io.gfile.glob(output_path + "/*")
expected_files = []
expected_files.append(os.path.join(output_path, "train.tfrecord"))
expected_files.append(os.path.join(output_path, "eval.tfrecord"))
for lang in processor.supported_languages:
expected_files.append(
os.path.join(output_path, "test_%s.tfrecord" % lang))
self.assertCountEqual(files, expected_files)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loads dataset for the tagging (e.g., NER/POS) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
"""Data config for tagging (tasks/tagging)."""
is_training: bool = True
seq_length: int = 128
include_sentence_id: bool = False
@data_loader_factory.register_data_loader_cls(TaggingDataConfig)
class TaggingDataLoader(data_loader.DataLoader):
"""A class to load dataset for tagging (e.g., NER and POS) task."""
def __init__(self, params: TaggingDataConfig):
self._params = params
self._seq_length = params.seq_length
self._include_sentence_id = params.include_sentence_id
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
}
if self._include_sentence_id:
name_to_features['sentence_id'] = tf.io.FixedLenFeature([], tf.int64)
name_to_features['sub_sentence_id'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _parse(self, record: Mapping[str, tf.Tensor]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x = {
'input_word_ids': record['input_ids'],
'input_mask': record['input_mask'],
'input_type_ids': record['segment_ids']
}
if self._include_sentence_id:
x['sentence_id'] = record['sentence_id']
x['sub_sentence_id'] = record['sub_sentence_id']
y = record['label_ids']
return (x, y)
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
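
# A minimal usage sketch (hypothetical path): reading records produced by
# `tagging_data_lib.write_example_to_file` back as a batched tf.data.Dataset.
#
#   config = TaggingDataConfig(
#       input_path="/tmp/tagging/train.tfrecord",
#       seq_length=128,
#       global_batch_size=32,
#       include_sentence_id=True)
#   dataset = TaggingDataLoader(config).load()
#   features, labels = next(iter(dataset))
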
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.data.tagging_data_loader."""
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.nlp.data import tagging_dataloader
def _create_fake_dataset(output_path, seq_length, include_sentence_id):
"""Creates a fake dataset."""
writer = tf.io.TFRecordWriter(output_path)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
for i in range(100):
features = {}
input_ids = np.random.randint(100, size=(seq_length))
features['input_ids'] = create_int_feature(input_ids)
features['input_mask'] = create_int_feature(np.ones_like(input_ids))
features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
features['label_ids'] = create_int_feature(
np.random.randint(10, size=(seq_length)))
if include_sentence_id:
features['sentence_id'] = create_int_feature([i])
features['sub_sentence_id'] = create_int_feature([0])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
class TaggingDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(True, False)
def test_load_dataset(self, include_sentence_id):
seq_length = 16
batch_size = 10
train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
_create_fake_dataset(train_data_path, seq_length, include_sentence_id)
data_config = tagging_dataloader.TaggingDataConfig(
input_path=train_data_path,
seq_length=seq_length,
global_batch_size=batch_size,
include_sentence_id=include_sentence_id)
dataset = tagging_dataloader.TaggingDataLoader(data_config).load()
features, labels = next(iter(dataset))
expected_keys = ['input_word_ids', 'input_mask', 'input_type_ids']
if include_sentence_id:
expected_keys.extend(['sentence_id', 'sub_sentence_id'])
self.assertCountEqual(expected_keys, features.keys())
self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(labels.shape, (batch_size, seq_length))
if include_sentence_id:
self.assertEqual(features['sentence_id'].shape, (batch_size,))
self.assertEqual(features['sub_sentence_id'].shape, (batch_size,))
if __name__ == '__main__':
tf.test.main()