Unverified commit f16a7b5b authored by vedanshu, committed by GitHub

Merge pull request #1 from tensorflow/master

new pull
parents 8e9296ff 8f58f396
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BERT configurations and models instantiation."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders
class BertModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
_ = bert.instantiate_pretrainer_from_cfg(config)
# Invokes with classification heads.
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_pretrainer_from_cfg(config)
with self.assertRaises(ValueError):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence"),
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_pretrainer_from_cfg(config)
def test_checkpoint_items(self):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
encoder = bert.instantiate_pretrainer_from_cfg(config)
self.assertSameElements(
encoder.checkpoint_items.keys(),
["encoder", "masked_lm", "next_sentence.pooler_dense"])
if __name__ == "__main__":
tf.test.main()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,71 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA model configurations and instantiation methods."""
from typing import List, Optional
from typing import List
import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import electra_pretrainer
@dataclasses.dataclass
class ELECTRAPretrainerConfig(base_config.Config):
class ElectraPretrainerConfig(base_config.Config):
"""ELECTRA pretrainer configuration."""
num_masked_tokens: int = 76
sequence_length: int = 512
num_classes: int = 2
discriminator_loss_weight: float = 50.0
generator_encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
discriminator_encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
tie_embeddings: bool = True
disallow_correct: bool = False
generator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
discriminator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)
def instantiate_classification_heads_from_cfgs(
cls_head_configs: List[bert.ClsHeadConfig]
) -> List[layers.ClassificationHead]:
if cls_head_configs:
return [
layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
]
else:
return []
def instantiate_pretrainer_from_cfg(
config: ELECTRAPretrainerConfig,
generator_network: Optional[tf.keras.Model] = None,
discriminator_network: Optional[tf.keras.Model] = None,
) -> electra_pretrainer.ElectraPretrainer:
"""Instantiates ElectraPretrainer from the config."""
generator_encoder_cfg = config.generator_encoder
discriminator_encoder_cfg = config.discriminator_encoder
if generator_network is None:
generator_network = encoders.instantiate_encoder_from_cfg(
generator_encoder_cfg)
if discriminator_network is None:
discriminator_network = encoders.instantiate_encoder_from_cfg(
discriminator_encoder_cfg)
return electra_pretrainer.ElectraPretrainer(
generator_network=generator_network,
discriminator_network=discriminator_network,
vocab_size=config.generator_encoder.vocab_size,
num_classes=config.num_classes,
sequence_length=config.sequence_length,
last_hidden_dim=config.generator_encoder.hidden_size,
num_token_predictions=config.num_masked_tokens,
mlm_activation=tf_utils.get_activation(
generator_encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=generator_encoder_cfg.initializer_range),
classification_heads=instantiate_classification_heads_from_cfgs(
config.cls_heads))
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA configurations and models instantiation."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
class ELECTRAModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
)
_ = electra.instantiate_pretrainer_from_cfg(config)
# Invokes with classification heads.
config = electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = electra.instantiate_pretrainer_from_cfg(config)
if __name__ == "__main__":
tf.test.main()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,22 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer Encoders.
Includes configurations and instantiation methods.
Includes configurations and factory methods.
"""
from typing import Optional
from absl import logging
import dataclasses
import gin
import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.modeling import networks
from official.nlp.projects.bigbird import encoder as bigbird_encoder
@dataclasses.dataclass
class TransformerEncoderConfig(base_config.Config):
class BertEncoderConfig(hyperparams.Config):
"""BERT encoder configuration."""
vocab_size: int = 30522
hidden_size: int = 768
@@ -40,56 +43,303 @@ class TransformerEncoderConfig(base_config.Config):
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
embedding_size: Optional[int] = None
output_range: Optional[int] = None
return_all_encoder_outputs: bool = False
@dataclasses.dataclass
class MobileBertEncoderConfig(hyperparams.Config):
"""MobileBERT encoder configuration.
Attributes:
word_vocab_size: number of words in the vocabulary.
word_embed_size: word embedding size.
type_vocab_size: number of word types.
max_sequence_length: maximum length of input sequence.
num_blocks: number of transformer blocks in the encoder model.
hidden_size: the hidden size for the transformer block.
num_attention_heads: number of attention heads in the transformer block.
intermediate_size: the size of the "intermediate" (a.k.a., feed forward)
layer.
hidden_activation: the non-linear activation function to apply to the
output of the intermediate/feed-forward layer.
hidden_dropout_prob: dropout probability for the hidden layers.
attention_probs_dropout_prob: dropout probability of the attention
probabilities.
intra_bottleneck_size: the size of the bottleneck.
initializer_range: the stddev of the truncated_normal_initializer for
initializing all weight matrices.
use_bottleneck_attention: whether to use attention inputs from the bottleneck
transformation. If true, the following `key_query_shared_bottleneck` will
be ignored.
key_query_shared_bottleneck: whether to share the linear transformation for
keys and queries.
num_feedforward_networks: number of stacked feed-forward networks.
normalization_type: the type of normalization; only 'no_norm' and
'layer_norm' are supported. 'no_norm' represents the element-wise linear
transformation for the student model, as suggested by the original
MobileBERT paper. 'layer_norm' is used for the teacher model.
classifier_activation: whether to use the tanh activation for the final
representation of the [CLS] token in fine-tuning.
input_mask_dtype: the data type of the input mask tensor.
"""
word_vocab_size: int = 30522
word_embed_size: int = 128
type_vocab_size: int = 2
max_sequence_length: int = 512
num_blocks: int = 24
hidden_size: int = 512
num_attention_heads: int = 4
intermediate_size: int = 4096
hidden_activation: str = "gelu"
hidden_dropout_prob: float = 0.1
attention_probs_dropout_prob: float = 0.1
intra_bottleneck_size: int = 1024
initializer_range: float = 0.02
use_bottleneck_attention: bool = False
key_query_shared_bottleneck: bool = False
num_feedforward_networks: int = 1
normalization_type: str = "layer_norm"
classifier_activation: bool = True
input_mask_dtype: str = "int32"
@dataclasses.dataclass
class AlbertEncoderConfig(hyperparams.Config):
"""ALBERT encoder configuration."""
vocab_size: int = 30000
embedding_width: int = 128
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
hidden_activation: str = "gelu"
intermediate_size: int = 3072
dropout_rate: float = 0.0
attention_dropout_rate: float = 0.0
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
@dataclasses.dataclass
class BigBirdEncoderConfig(hyperparams.Config):
"""BigBird encoder configuration."""
vocab_size: int = 50358
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
hidden_activation: str = "gelu"
intermediate_size: int = 3072
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
max_position_embeddings: int = 4096
num_rand_blocks: int = 3
block_size: int = 64
type_vocab_size: int = 16
initializer_range: float = 0.02
embedding_width: Optional[int] = None
use_gradient_checkpointing: bool = False
@dataclasses.dataclass
class XLNetEncoderConfig(hyperparams.Config):
"""XLNet encoder configuration."""
vocab_size: int = 32000
num_layers: int = 24
hidden_size: int = 1024
num_attention_heads: int = 16
head_size: int = 64
inner_size: int = 4096
inner_activation: str = "gelu"
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
attention_type: str = "bi"
bi_data: bool = False
tie_attention_biases: bool = False
memory_length: int = 0
same_length: bool = False
clamp_length: int = -1
reuse_length: int = 0
use_cls_mask: bool = False
embedding_width: int = 1024
initializer_range: float = 0.02
two_stream: bool = False
@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
"""Encoder configuration."""
type: Optional[str] = "bert"
albert: AlbertEncoderConfig = AlbertEncoderConfig()
bert: BertEncoderConfig = BertEncoderConfig()
bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
ENCODER_CLS = {
"bert": networks.BertEncoder,
"mobilebert": networks.MobileBERTEncoder,
"albert": networks.AlbertEncoder,
"bigbird": bigbird_encoder.BigBirdEncoder,
"xlnet": networks.XLNetBase,
}
@gin.configurable
def instantiate_encoder_from_cfg(config: TransformerEncoderConfig,
encoder_cls=networks.TransformerEncoder):
"""Instantiate a Transformer encoder network from TransformerEncoderConfig."""
def build_encoder(config: EncoderConfig,
embedding_layer: Optional[tf.keras.layers.Layer] = None,
encoder_cls=None,
bypass_config: bool = False):
"""Instantiate a Transformer encoder network from EncoderConfig.
Args:
config: the one-of encoder config, which provides encoder parameters of a
chosen encoder.
embedding_layer: an external embedding layer passed to the encoder.
encoder_cls: an external encoder cls not included in the supported encoders,
usually used by gin.configurable.
bypass_config: whether to ignore config instance to create the object with
`encoder_cls`.
Returns:
An encoder instance.
"""
encoder_type = config.type
encoder_cfg = config.get()
encoder_cls = encoder_cls or ENCODER_CLS[encoder_type]
logging.info("Encoder class: %s to build...", encoder_cls.__name__)
if bypass_config:
return encoder_cls()
if encoder_cls.__name__ == "EncoderScaffold":
embedding_cfg = dict(
vocab_size=config.vocab_size,
type_vocab_size=config.type_vocab_size,
hidden_size=config.hidden_size,
seq_length=None,
max_seq_length=config.max_position_embeddings,
vocab_size=encoder_cfg.vocab_size,
type_vocab_size=encoder_cfg.type_vocab_size,
hidden_size=encoder_cfg.hidden_size,
max_seq_length=encoder_cfg.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
dropout_rate=config.dropout_rate,
stddev=encoder_cfg.initializer_range),
dropout_rate=encoder_cfg.dropout_rate,
)
hidden_cfg = dict(
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
intermediate_activation=tf_utils.get_activation(
config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
stddev=encoder_cfg.initializer_range),
)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cfg=hidden_cfg,
num_hidden_instances=config.num_layers,
pooled_output_dim=config.hidden_size,
num_hidden_instances=encoder_cfg.num_layers,
pooled_output_dim=encoder_cfg.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range))
stddev=encoder_cfg.initializer_range),
return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
dict_outputs=True)
return encoder_cls(**kwargs)
if encoder_cls.__name__ != "TransformerEncoder":
raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
encoder_network = encoder_cls(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
num_layers=config.num_layers,
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
activation=tf_utils.get_activation(config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
sequence_length=None,
max_sequence_length=config.max_position_embeddings,
type_vocab_size=config.type_vocab_size,
if encoder_type == "mobilebert":
return encoder_cls(
word_vocab_size=encoder_cfg.word_vocab_size,
word_embed_size=encoder_cfg.word_embed_size,
type_vocab_size=encoder_cfg.type_vocab_size,
max_sequence_length=encoder_cfg.max_sequence_length,
num_blocks=encoder_cfg.num_blocks,
hidden_size=encoder_cfg.hidden_size,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
intermediate_act_fn=encoder_cfg.hidden_activation,
hidden_dropout_prob=encoder_cfg.hidden_dropout_prob,
attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob,
intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
initializer_range=encoder_cfg.initializer_range,
use_bottleneck_attention=encoder_cfg.use_bottleneck_attention,
key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck,
num_feedforward_networks=encoder_cfg.num_feedforward_networks,
normalization_type=encoder_cfg.normalization_type,
classifier_activation=encoder_cfg.classifier_activation,
input_mask_dtype=encoder_cfg.input_mask_dtype)
if encoder_type == "albert":
return encoder_cls(
vocab_size=encoder_cfg.vocab_size,
embedding_width=encoder_cfg.embedding_width,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
dict_outputs=True)
if encoder_type == "bigbird":
return encoder_cls(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
num_rand_blocks=encoder_cfg.num_rand_blocks,
block_size=encoder_cfg.block_size,
max_position_embeddings=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
embedding_width=encoder_cfg.embedding_width,
use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing)
if encoder_type == "xlnet":
return encoder_cls(
vocab_size=encoder_cfg.vocab_size,
num_layers=encoder_cfg.num_layers,
hidden_size=encoder_cfg.hidden_size,
num_attention_heads=encoder_cfg.num_attention_heads,
head_size=encoder_cfg.head_size,
inner_size=encoder_cfg.inner_size,
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
attention_type=encoder_cfg.attention_type,
bi_data=encoder_cfg.bi_data,
two_stream=encoder_cfg.two_stream,
tie_attention_biases=encoder_cfg.tie_attention_biases,
memory_length=encoder_cfg.memory_length,
clamp_length=encoder_cfg.clamp_length,
reuse_length=encoder_cfg.reuse_length,
inner_activation=encoder_cfg.inner_activation,
use_cls_mask=encoder_cfg.use_cls_mask,
embedding_width=encoder_cfg.embedding_width,
initializer=tf.keras.initializers.RandomNormal(
stddev=encoder_cfg.initializer_range))
# Uses the default BertEncoder configuration schema to create the encoder.
# If it does not match, please add a switch branch above keyed on the encoder
# type.
return encoder_cls(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range))
return encoder_network
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_width=encoder_cfg.embedding_size,
embedding_layer=embedding_layer,
return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
dict_outputs=True)
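
A minimal usage sketch of `build_encoder` with the one-of `EncoderConfig`; the tiny vocabulary, the two-layer depth, and the forward-pass shapes below are illustrative assumptions rather than values taken from this change.

import numpy as np

from official.nlp.configs import encoders

# The `type` field of the one-of config picks the sub-config and the class
# from ENCODER_CLS; switching it to e.g. "albert" or "mobilebert" swaps the
# backbone without changing the calling code.
config = encoders.EncoderConfig(
    type="bert",
    bert=encoders.BertEncoderConfig(vocab_size=100, num_layers=2))
encoder = encoders.build_encoder(config)

# The built encoder is a Keras model over word ids, mask and type ids.
batch_size, seq_length = 2, 8
outputs = encoder(
    dict(
        input_word_ids=np.ones((batch_size, seq_length), dtype=np.int32),
        input_mask=np.ones((batch_size, seq_length), dtype=np.int32),
        input_type_ids=np.zeros((batch_size, seq_length), dtype=np.int32)))
print(outputs["sequence_output"].shape)  # (2, 8, hidden_size)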
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.configs.encoders."""
import os
import tensorflow as tf
from official.modeling import hyperparams
from official.nlp.configs import encoders
class EncodersTest(tf.test.TestCase):
def test_encoder_from_yaml(self):
config = encoders.EncoderConfig(
type="bert", bert=encoders.BertEncoderConfig(num_layers=1))
encoder = encoders.build_encoder(config)
ckpt = tf.train.Checkpoint(encoder=encoder)
ckpt_path = ckpt.save(self.get_temp_dir() + "/ckpt")
params_save_path = os.path.join(self.get_temp_dir(), "params.yaml")
hyperparams.save_params_dict_to_yaml(config, params_save_path)
restored_cfg = encoders.EncoderConfig.from_yaml(params_save_path)
restored_encoder = encoders.build_encoder(restored_cfg)
status = tf.train.Checkpoint(encoder=restored_encoder).restore(ckpt_path)
status.assert_consumed()
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experiments definition."""
# pylint: disable=unused-import
from official.nlp.configs import finetuning_experiments
from official.nlp.configs import pretraining_experiments
from official.nlp.configs import wmt_transformer_experiments
task:
hub_module_url: ''
model:
num_classes: 3
init_checkpoint: ''
metric_type: 'accuracy'
train_data:
drop_remainder: true
global_batch_size: 32
input_path: ''
is_training: true
seq_length: 128
label_type: 'int'
validation_data:
drop_remainder: false
global_batch_size: 32
input_path: ''
is_training: false
seq_length: 128
label_type: 'int'
trainer:
checkpoint_interval: 3000
optimizer_config:
learning_rate:
polynomial:
# 100% of train_steps.
decay_steps: 36813
end_learning_rate: 0.0
initial_learning_rate: 3.0e-05
power: 1.0
type: polynomial
optimizer:
type: adamw
warmup:
polynomial:
power: 1
# ~10% of train_steps.
warmup_steps: 3681
type: polynomial
steps_per_loop: 1000
summary_interval: 1000
# Training data size 392,702 examples, 3 epochs.
train_steps: 36813
validation_interval: 6135
# Eval data size = 9815 examples.
validation_steps: 307
best_checkpoint_export_subdir: 'best_ckpt'
best_checkpoint_eval_metric: 'cls_accuracy'
best_checkpoint_metric_comp: 'higher'
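
The step counts above follow from the dataset sizes noted in the comments; a rough sanity-check sketch in Python (the exact rounding used to arrive at 36813 is an assumption):

import math

train_examples = 392_702  # from the comment above: training data size
eval_examples = 9_815     # from the comment above: eval data size
epochs = 3
batch_size = 32

train_steps = math.ceil(train_examples * epochs / batch_size)  # 36816
warmup_steps = round(0.1 * train_steps)                         # 3682
validation_steps = math.ceil(eval_examples / batch_size)        # 307

# The configured values (36813 train steps, 3681 warmup steps, 307 validation
# steps) match these estimates up to rounding.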
task:
hub_module_url: ''
max_answer_length: 30
n_best_size: 20
null_score_diff_threshold: 0.0
init_checkpoint: ''
train_data:
drop_remainder: true
global_batch_size: 48
input_path: ''
is_training: true
seq_length: 384
validation_data:
do_lower_case: true
doc_stride: 128
drop_remainder: false
global_batch_size: 48
input_path: ''
is_training: false
query_length: 64
seq_length: 384
tokenization: WordPiece
version_2_with_negative: false
vocab_file: ''
trainer:
checkpoint_interval: 1000
max_to_keep: 5
optimizer_config:
learning_rate:
polynomial:
decay_steps: 3699
end_learning_rate: 0.0
initial_learning_rate: 8.0e-05
power: 1.0
type: polynomial
optimizer:
type: adamw
warmup:
polynomial:
power: 1
warmup_steps: 370
type: polynomial
steps_per_loop: 1000
summary_interval: 1000
train_steps: 3699
validation_interval: 1000
validation_steps: 226
best_checkpoint_export_subdir: 'best_ckpt'
best_checkpoint_eval_metric: 'final_f1'
best_checkpoint_metric_comp: 'higher'
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning experiment configurations."""
# pylint: disable=g-doc-return-or-yield,line-too-long
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.data import question_answering_dataloader
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.data import tagging_dataloader
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging
@exp_factory.register_config_factory('bert/sentence_prediction')
def bert_sentence_prediction() -> cfg.ExperimentConfig:
r"""BERT GLUE."""
config = cfg.ExperimentConfig(
task=sentence_prediction.SentencePredictionConfig(
train_data=sentence_prediction_dataloader
.SentencePredictionDataConfig(),
validation_data=sentence_prediction_dataloader
.SentencePredictionDataConfig(
is_training=False, drop_remainder=False)),
trainer=cfg.TrainerConfig(
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate':
0.01,
'exclude_from_weight_decay':
['LayerNorm', 'layer_norm', 'bias'],
}
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 3e-5,
'end_learning_rate': 0.0,
}
},
'warmup': {
'type': 'polynomial'
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
config.task.model.encoder.type = 'bert'
return config
@exp_factory.register_config_factory('bert/squad')
def bert_squad() -> cfg.ExperimentConfig:
"""BERT Squad V1/V2."""
config = cfg.ExperimentConfig(
task=question_answering.QuestionAnsweringConfig(
train_data=question_answering_dataloader.QADataConfig(),
validation_data=question_answering_dataloader.QADataConfig()),
trainer=cfg.TrainerConfig(
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate':
0.01,
'exclude_from_weight_decay':
['LayerNorm', 'layer_norm', 'bias'],
}
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 8e-5,
'end_learning_rate': 0.0,
}
},
'warmup': {
'type': 'polynomial'
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
config.task.model.encoder.type = 'bert'
return config
@exp_factory.register_config_factory('bert/tagging')
def bert_tagging() -> cfg.ExperimentConfig:
"""BERT tagging task."""
config = cfg.ExperimentConfig(
task=tagging.TaggingConfig(
train_data=tagging_dataloader.TaggingDataConfig(),
validation_data=tagging_dataloader.TaggingDataConfig(
is_training=False, drop_remainder=False)),
trainer=cfg.TrainerConfig(
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate':
0.01,
'exclude_from_weight_decay':
['LayerNorm', 'layer_norm', 'bias'],
}
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 8e-5,
'end_learning_rate': 0.0,
}
},
'warmup': {
'type': 'polynomial'
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
])
return config
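
A short usage sketch of the factories above; `exp_factory.get_exp_config` is assumed to be the lookup counterpart of `register_config_factory`, and the ALBERT tweak only illustrates the one-of encoder field set at the end of each factory.

from official.core import exp_factory

config = exp_factory.get_exp_config('bert/sentence_prediction')
assert config.task.model.encoder.type == 'bert'

# The encoder is a one-of config: flipping `type` makes build_encoder read the
# `albert` sub-config instead, without touching the rest of the experiment.
config.task.model.encoder.type = 'albert'
config.task.model.encoder.albert.num_layers = 6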
task:
model:
encoder:
type: bert
bert:
attention_dropout_rate: 0.1
dropout_rate: 0.1
hidden_activation: gelu
hidden_size: 768
initializer_range: 0.02
intermediate_size: 3072
max_position_embeddings: 512
num_attention_heads: 12
num_layers: 12
type_vocab_size: 2
vocab_size: 30522
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretraining experiment configurations."""
# pylint: disable=g-doc-return-or-yield,line-too-long
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.data import pretrain_dataloader
from official.nlp.data import pretrain_dynamic_dataloader
from official.nlp.tasks import masked_lm
_TRAINER = cfg.TrainerConfig(
train_steps=1000000,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate':
0.01,
'exclude_from_weight_decay': [
'LayerNorm', 'layer_norm', 'bias'
],
}
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 1e-4,
'end_learning_rate': 0.0,
}
},
'warmup': {
'type': 'polynomial'
}
}))
@exp_factory.register_config_factory('bert/pretraining')
def bert_pretraining() -> cfg.ExperimentConfig:
"""BERT pretraining experiment."""
config = cfg.ExperimentConfig(
task=masked_lm.MaskedLMConfig(
train_data=pretrain_dataloader.BertPretrainDataConfig(),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
is_training=False)),
trainer=_TRAINER,
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('bert/pretraining_dynamic')
def bert_dynamic() -> cfg.ExperimentConfig:
"""BERT base with dynamic input sequences.
On TPU, this needs to run with the tf.data service using round-robin behavior.
"""
config = cfg.ExperimentConfig(
task=masked_lm.MaskedLMConfig(
train_data=pretrain_dynamic_dataloader.BertPretrainDataConfig(),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
is_training=False)),
trainer=_TRAINER,
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# pylint: disable=g-doc-return-or-yield,line-too-long
"""WMT translation configurations."""
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.data import wmt_dataloader
from official.nlp.tasks import translation
@exp_factory.register_config_factory('wmt_transformer/large')
def wmt_transformer_large() -> cfg.ExperimentConfig:
"""WMT Transformer Large.
Please refer to tensorflow_models/official/nlp/data/train_sentencepiece.py
to generate sentencepiece_model and pass
--params_override=task.sentencepiece_model_path='YOUR_PATH'
to the train script.
"""
learning_rate = 2.0
hidden_size = 1024
learning_rate *= (hidden_size**-0.5)
warmup_steps = 16000
train_steps = 300000
token_batch_size = 24576
encdecoder = translation.EncDecoder(
num_attention_heads=16, intermediate_size=hidden_size * 4)
config = cfg.ExperimentConfig(
task=translation.TranslationConfig(
model=translation.ModelConfig(
encoder=encdecoder,
decoder=encdecoder,
embedding_width=hidden_size,
padded_decode=True,
decode_max_length=100),
train_data=wmt_dataloader.WMTDataConfig(
tfds_name='wmt14_translate/de-en',
tfds_split='train',
src_lang='en',
tgt_lang='de',
is_training=True,
global_batch_size=token_batch_size,
static_batch=True,
max_seq_length=64
),
validation_data=wmt_dataloader.WMTDataConfig(
tfds_name='wmt14_translate/de-en',
tfds_split='test',
src_lang='en',
tgt_lang='de',
is_training=False,
global_batch_size=32,
static_batch=True,
max_seq_length=100,
),
sentencepiece_model_path=None,
),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=-1,
steps_per_loop=1000,
summary_interval=1000,
checkpoint_interval=5000,
validation_interval=5000,
max_to_keep=1,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adam',
'adam': {
'beta_2': 0.997,
'epsilon': 1e-9,
},
},
'learning_rate': {
'type': 'power',
'power': {
'initial_learning_rate': learning_rate,
'power': -0.5,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': warmup_steps,
'warmup_learning_rate': 0.0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.sentencepiece_model_path != None',
])
return config
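
A sketch of the programmatic equivalent of the `--params_override` usage mentioned in the docstring above (the sentencepiece path is a placeholder):

from official.core import exp_factory

config = exp_factory.get_exp_config('wmt_transformer/large')
config.override(
    {'task': {'sentencepiece_model_path': '/path/to/sentencepiece.model'}},
    is_strict=False)
config.validate()  # 'task.sentencepiece_model_path != None' is now satisfied.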
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM continuous finetuning+eval training driver library."""
import gc
import os
import time
from typing import Any, Mapping, Optional
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.core import config_definitions
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import train_lib as multitask_train_lib
def _flatten_dict(xs):
"""Flatten a nested dictionary.
The nested keys are flattened to a tuple.
Example::
xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
flat_xs = flatten_dict(xs)
print(flat_xs)
# {
# ('foo',): 1,
# ('bar', 'a'): 2,
# }
Note that empty dictionaries are ignored and
will not be restored by `unflatten_dict`.
Args:
xs: a nested dictionary
Returns:
The flattened dictionary.
"""
assert isinstance(xs, dict), 'input is not a dict'
def _flatten(xs, prefix):
if not isinstance(xs, dict):
return {prefix: xs}
result = {}
for key, value in xs.items():
path = prefix + (key,)
result.update(_flatten(value, path))
return result
return _flatten(xs, ())
def run_continuous_finetune(
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
pretrain_steps: Optional[int] = None,
) -> Mapping[str, Any]:
"""Run modes with continuous training.
Currently only supports continuous_train_and_eval.
Args:
mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
checkpoint directory. Once a new checkpoint is discovered, loads the
checkpoint, finetunes the model by training it (probably on another dataset
or with another task), then evaluates the finetuned model.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run one post-training evaluation and return its
metrics logs.
pretrain_steps: Optional, the number of total training steps for the
pretraining job.
Returns:
eval logs: returns eval metrics logs when run_post_eval is set to True,
otherwise returns {}.
"""
assert mode == 'continuous_train_and_eval', (
'Only continuous_train_and_eval is supported by continuous_finetune. '
'Got mode: {}'.format(mode))
# Sets the mixed precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have a significant impact on model speed by using float16 on GPUs and
# bfloat16 on TPUs. loss_scale takes effect only when the dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
retry_times = 0
while not tf.io.gfile.isdir(params.task.init_checkpoint):
# Wait for the init_checkpoint directory to be created.
if retry_times >= 60:
raise ValueError(
'ExperimentConfig.task.init_checkpoint must be a directory for '
'continuous_train_and_eval mode.')
retry_times += 1
time.sleep(60)
summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, 'eval'))
global_step = 0
def timeout_fn():
if pretrain_steps and global_step < pretrain_steps:
# Keeps waiting for another timeout period.
logging.info(
'Continue waiting for new checkpoint as current pretrain '
'global_step=%d and target is %d.', global_step, pretrain_steps)
return False
# Quits the loop.
return True
for pretrain_ckpt in tf.train.checkpoints_iterator(
checkpoint_dir=params.task.init_checkpoint,
min_interval_secs=10,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn):
# If there are checkpoints, they might be finetuning checkpoints from a
# different pretrained checkpoint, so remove all existing checkpoints.
train_utils.remove_ckpts(model_dir)
with distribution_strategy.scope():
global_step = train_utils.read_global_step_from_checkpoint(pretrain_ckpt)
# Replaces params.task.init_checkpoint to make sure that we load
# exactly this pretrain checkpoint.
if params.trainer.best_checkpoint_export_subdir:
best_ckpt_subdir = '{}_{}'.format(
params.trainer.best_checkpoint_export_subdir, global_step)
params_replaced = params.replace(
task={'init_checkpoint': pretrain_ckpt},
trainer={'best_checkpoint_export_subdir': best_ckpt_subdir})
else:
params_replaced = params.replace(task={'init_checkpoint': pretrain_ckpt})
params_replaced.lock()
logging.info('Running finetuning with params: %s', params_replaced)
with distribution_strategy.scope():
if isinstance(params, configs.MultiEvalExperimentConfig):
task = task_factory.get_task(params_replaced.task)
eval_tasks = multitask.MultiTask.from_config(params_replaced.eval_tasks)
(_,
eval_metrics) = multitask_train_lib.run_experiment_with_multitask_eval(
distribution_strategy=distribution_strategy,
train_task=task,
eval_tasks=eval_tasks,
mode='train_and_eval',
params=params_replaced,
model_dir=model_dir,
run_post_eval=True,
save_summary=False)
else:
task = task_factory.get_task(
params_replaced.task, logging_dir=model_dir)
_, eval_metrics = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode='train_and_eval',
params=params_replaced,
model_dir=model_dir,
run_post_eval=True,
save_summary=False)
logging.info('Evaluation finished. Pretrain global_step: %d', global_step)
train_utils.write_json_summary(model_dir, global_step, eval_metrics)
if not os.path.basename(model_dir): # if model_dir.endswith('/')
summary_grp = os.path.dirname(model_dir) + '_' + task.name
else:
summary_grp = os.path.basename(model_dir) + '_' + task.name
summaries = {}
for name, value in _flatten_dict(eval_metrics).items():
summaries[summary_grp + '/' + '-'.join(name)] = value
train_utils.write_summary(summary_writer, global_step, summaries)
train_utils.remove_ckpts(model_dir)
# In TF2, the resource life cycle is bound to the Python object life cycle.
# Force Python garbage collection here so those resources are deallocated in
# time and do not cause OOM when allocating new objects.
# TODO(b/169178664): Fix cycle reference in Keras model and revisit to see
# if we need gc here.
gc.collect()
if run_post_eval:
return eval_metrics
return {}
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow as tf
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.nlp import continuous_finetune_lib
FLAGS = flags.FLAGS
tfm_flags.define_flags()
class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
def testContinuousFinetune(self):
pretrain_steps = 1
src_model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode='continuous_train_and_eval',
model_dir=self._model_dir,
params_override={
'task': {
'init_checkpoint': src_model_dir,
},
'trainer': {
'continuous_eval_timeout': 1,
'steps_per_loop': 1,
'train_steps': 1,
'validation_steps': 1,
'best_checkpoint_export_subdir': 'best_ckpt',
'best_checkpoint_eval_metric': 'acc',
'optimizer_config': {
'optimizer': {
'type': 'sgd'
},
'learning_rate': {
'type': 'constant'
}
}
}
})
with flagsaver.flagsaver(**flags_dict):
# Train and save some checkpoints.
params = train_utils.parse_configuration(flags.FLAGS)
distribution_strategy = tf.distribute.get_strategy()
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=src_model_dir)
_ = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode='train',
params=params,
model_dir=src_model_dir)
params = train_utils.parse_configuration(FLAGS)
eval_metrics = continuous_finetune_lib.run_continuous_finetune(
FLAGS.mode,
params,
FLAGS.model_dir,
run_post_eval=True,
pretrain_steps=pretrain_steps)
self.assertIn('best_acc', eval_metrics)
self.assertFalse(
tf.io.gfile.exists(os.path.join(FLAGS.model_dir, 'checkpoint')))
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,16 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT library to process data for classification task."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""BERT library to process data for classification task."""
import collections
import csv
import importlib
import json
import os
from absl import logging
@@ -39,7 +36,7 @@ class InputExample(object):
text_b=None,
label=None,
weight=None,
int_iden=None):
example_id=None):
"""Constructs a InputExample.
Args:
@@ -53,15 +50,15 @@ class InputExample(object):
examples, but not for test examples.
weight: (Optional) float. The weight of the example to be used during
training.
int_iden: (Optional) int. The int identification number of example in the
corpus.
example_id: (Optional) int. The int identification number of example in
the corpus.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
self.weight = weight
self.int_iden = int_iden
self.example_id = example_id
class InputFeatures(object):
@@ -74,14 +71,14 @@ class InputFeatures(object):
label_id,
is_real_example=True,
weight=None,
int_iden=None):
example_id=None):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.is_real_example = is_real_example
self.weight = weight
self.int_iden = int_iden
self.example_id = example_id
class DataProcessor(object):
@@ -123,6 +120,63 @@ class DataProcessor(object):
lines.append(line)
return lines
@classmethod
def _read_jsonl(cls, input_file):
"""Reads a json line file."""
with tf.io.gfile.GFile(input_file, "r") as f:
lines = []
for json_str in f:
lines.append(json.loads(json_str))
return lines
class AxProcessor(DataProcessor):
"""Processor for the AX dataset (GLUE diagnostics dataset)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "AX"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
text_a_index = 1 if set_type == "test" else 8
text_b_index = 2 if set_type == "test" else 9
examples = []
for i, line in enumerate(lines):
# Skip header.
if i == 0:
continue
guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
text_a = self.process_text_fn(line[text_a_index])
text_b = self.process_text_fn(line[text_b_index])
if set_type == "test":
label = "contradiction"
else:
label = self.process_text_fn(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
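# A small usage sketch of the processor above (hypothetical helper, the data
# path is a placeholder): every processor exposes the same get_*_examples /
# get_labels interface, so downstream feature-conversion code can stay
# processor-agnostic.
def _ax_processor_usage_sketch(data_dir="/path/to/glue/AX"):
  processor = AxProcessor()
  labels = processor.get_labels()
  test_examples = processor.get_test_examples(data_dir)
  return processor.get_processor_name(), len(labels), len(test_examples)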
class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
@@ -152,10 +206,10 @@ class ColaProcessor(DataProcessor):
return "COLA"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
# Only the test set has a header
for i, line in enumerate(lines):
# Only the test set has a header.
if set_type == "test" and i == 0:
continue
guid = "%s-%s" % (set_type, i)
@@ -170,9 +224,55 @@ class ColaProcessor(DataProcessor):
return examples
class ImdbProcessor(DataProcessor):
"""Processor for the IMDb dataset."""
def get_labels(self):
return ["neg", "pos"]
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train"))
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test"))
@staticmethod
def get_processor_name():
"""See base class."""
return "IMDB"
def _create_examples(self, data_dir):
"""Creates examples."""
examples = []
for label in ["neg", "pos"]:
cur_dir = os.path.join(data_dir, label)
for filename in tf.io.gfile.listdir(cur_dir):
if not filename.endswith("txt"):
continue
if len(examples) % 1000 == 0:
logging.info("Loading dev example %d", len(examples))
path = os.path.join(cur_dir, filename)
with tf.io.gfile.GFile(path, "r") as f:
text = f.read().strip().replace("<br />", " ")
examples.append(
InputExample(
guid="unused_id", text_a=text, text_b=None, label=label))
return examples
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def __init__(self,
mnli_type="matched",
process_text_fn=tokenization.convert_to_unicode):
super(MnliProcessor, self).__init__(process_text_fn)
if mnli_type not in ("matched", "mismatched"):
raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
self.mnli_type = mnli_type
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
@@ -180,14 +280,23 @@ class MnliProcessor(DataProcessor):
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
if self.mnli_type == "matched":
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
else:
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
"dev_mismatched")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
if self.mnli_type == "matched":
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
else:
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test")
def get_labels(self):
"""See base class."""
@@ -199,9 +308,9 @@ class MnliProcessor(DataProcessor):
return "MNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
@@ -244,9 +353,9 @@ class MrpcProcessor(DataProcessor):
return "MRPC"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
@@ -290,7 +399,7 @@ class PawsxProcessor(DataProcessor):
self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2])
@@ -307,7 +416,7 @@ class PawsxProcessor(DataProcessor):
self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:])
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2])
@@ -321,7 +430,7 @@ class PawsxProcessor(DataProcessor):
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:]
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "test-%d" % i
text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2])
@@ -368,9 +477,9 @@ class QnliProcessor(DataProcessor):
return "QNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, 1)
@@ -415,18 +524,24 @@ class QqpProcessor(DataProcessor):
return "QQP"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
try:
text_a = line[3]
text_b = line[4]
label = line[5]
except IndexError:
continue
if set_type == "test":
text_a = line[1]
text_b = line[2]
label = "0"
else:
# There appear to be some garbage lines in the train dataset.
try:
text_a = line[3]
text_b = line[4]
label = line[5]
except IndexError:
continue
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
@@ -462,7 +577,7 @@ class RteProcessor(DataProcessor):
return "RTE"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for i, line in enumerate(lines):
if i == 0:
@@ -507,9 +622,9 @@ class SstProcessor(DataProcessor):
return "SST-2"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
@@ -558,7 +673,7 @@ class StsBProcessor(DataProcessor):
return "STS-B"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for i, line in enumerate(lines):
if i == 0:
@@ -671,7 +786,7 @@ class TfdsProcessor(DataProcessor):
return "TFDS_" + self.dataset_name
def _create_examples(self, split_name, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
if split_name not in self.dataset:
raise ValueError("Split {} not available.".format(split_name))
dataset = self.dataset[split_name].as_numpy_iterator()
@@ -731,7 +846,7 @@ class WnliProcessor(DataProcessor):
return "WNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
"""Creates examples for the training/dev/test sets."""
examples = []
for i, line in enumerate(lines):
if i == 0:
@@ -777,7 +892,7 @@ class XnliProcessor(DataProcessor):
"multinli.train.%s.tsv" % language))[1:])
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
@@ -792,7 +907,7 @@ class XnliProcessor(DataProcessor):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
examples = []
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "dev-%d" % i
@@ -807,7 +922,7 @@ class XnliProcessor(DataProcessor):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
for (i, line) in enumerate(lines):
for i, line in enumerate(lines):
if i == 0:
continue
guid = "test-%d" % i
@@ -833,45 +948,104 @@ class XtremePawsxProcessor(DataProcessor):
"""Processor for the XTREME PAWS-X data set."""
supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
def __init__(self,
process_text_fn=tokenization.convert_to_unicode,
translated_data_dir=None,
only_use_en_dev=True):
"""See base class.
Args:
process_text_fn: See base class.
translated_data_dir: If specified, will also include translated data in
the training and testing data.
      only_use_en_dev: If True, only use English dev data. Otherwise, use dev
data from all languages.
"""
super(XtremePawsxProcessor, self).__init__(process_text_fn)
self.translated_data_dir = translated_data_dir
self.only_use_en_dev = only_use_en_dev
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
if self.translated_data_dir is None:
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
else:
for lang in self.supported_languages:
lines = self._read_tsv(
os.path.join(self.translated_data_dir, "translate-train",
f"en-{lang}-translated.tsv"))
for i, line in enumerate(lines):
guid = f"train-{lang}-{i}"
text_a = self.process_text_fn(line[2])
text_b = self.process_text_fn(line[3])
label = self.process_text_fn(line[4])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
if self.only_use_en_dev:
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
for i, line in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
else:
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv"))
for i, line in enumerate(lines):
guid = f"dev-{lang}-{i}"
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in self.supported_languages}
examples_by_lang = {}
for lang in self.supported_languages:
examples_by_lang[lang] = []
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for (i, line) in enumerate(lines):
guid = "test-%d" % i
for i, line in enumerate(lines):
guid = f"test-{lang}-{i}"
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = "0"
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
if self.translated_data_dir is not None:
for lang in self.supported_languages:
if lang == "en":
continue
examples_by_lang[f"{lang}-en"] = []
lines = self._read_tsv(
os.path.join(self.translated_data_dir, "translate-test",
f"test-{lang}-en-translated.tsv"))
for i, line in enumerate(lines):
guid = f"test-{lang}-en-{i}"
text_a = self.process_text_fn(line[2])
text_b = self.process_text_fn(line[3])
label = "0"
examples_by_lang[f"{lang}-en"].append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
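  # Note (added for clarity, not in the original file): the dict returned by
  # `get_test_examples` is keyed by language code ("de", "en", ...), plus a
  # "{lang}-en" key per non-English language when `translated_data_dir` is set.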
def get_labels(self):
......@@ -891,45 +1065,111 @@ class XtremeXnliProcessor(DataProcessor):
"ur", "vi", "zh"
]
def __init__(self,
process_text_fn=tokenization.convert_to_unicode,
translated_data_dir=None,
only_use_en_dev=True):
"""See base class.
Args:
process_text_fn: See base class.
translated_data_dir: If specified, will also include translated data in
the training data.
      only_use_en_dev: If True, only use English dev data. Otherwise, use dev
data from all languages.
"""
super(XtremeXnliProcessor, self).__init__(process_text_fn)
self.translated_data_dir = translated_data_dir
self.only_use_en_dev = only_use_en_dev
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
if self.translated_data_dir is None:
for i, line in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
if label == self.process_text_fn("contradictory"):
label = self.process_text_fn("contradiction")
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
else:
for lang in self.supported_languages:
lines = self._read_tsv(
os.path.join(self.translated_data_dir, "translate-train",
f"en-{lang}-translated.tsv"))
for i, line in enumerate(lines):
guid = f"train-{lang}-{i}"
text_a = self.process_text_fn(line[2])
text_b = self.process_text_fn(line[3])
label = self.process_text_fn(line[4])
if label == self.process_text_fn("contradictory"):
label = self.process_text_fn("contradiction")
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
if self.only_use_en_dev:
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
for i, line in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
else:
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv"))
for i, line in enumerate(lines):
guid = f"dev-{lang}-{i}"
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
if label == self.process_text_fn("contradictory"):
label = self.process_text_fn("contradiction")
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in self.supported_languages}
examples_by_lang = {}
for lang in self.supported_languages:
examples_by_lang[lang] = []
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for (i, line) in enumerate(lines):
guid = f"test-{i}"
for i, line in enumerate(lines):
guid = f"test-{lang}-{i}"
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = "contradiction"
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
if self.translated_data_dir is not None:
for lang in self.supported_languages:
if lang == "en":
continue
examples_by_lang[f"{lang}-en"] = []
lines = self._read_tsv(
os.path.join(self.translated_data_dir, "translate-test",
f"test-{lang}-en-translated.tsv"))
for i, line in enumerate(lines):
guid = f"test-{lang}-en-{i}"
text_a = self.process_text_fn(line[2])
text_b = self.process_text_fn(line[3])
label = "contradiction"
examples_by_lang[f"{lang}-en"].append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self):
......@@ -965,6 +1205,11 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[0:(max_seq_length - 2)]
seg_id_a = 0
seg_id_b = 1
seg_id_cls = 0
seg_id_pad = 0
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
......@@ -986,19 +1231,19 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
segment_ids.append(seg_id_cls)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
segment_ids.append(seg_id_a)
tokens.append("[SEP]")
segment_ids.append(0)
segment_ids.append(seg_id_a)
if tokens_b:
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
segment_ids.append(seg_id_b)
tokens.append("[SEP]")
segment_ids.append(1)
segment_ids.append(seg_id_b)
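  # Illustrative example (added for clarity, not in the original file): for
  # tokens_a = ["hi", "there"] and tokens_b = ["bye"], the lists built above are
  #   tokens      = ["[CLS]", "hi", "there", "[SEP]", "bye", "[SEP]"]
  #   segment_ids = [0, 0, 0, 0, 1, 1]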
input_ids = tokenizer.convert_tokens_to_ids(tokens)
......@@ -1010,7 +1255,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
segment_ids.append(seg_id_pad)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
......@@ -1027,7 +1272,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
logging.info("label: %s (id = %s)", example.label, str(label_id))
logging.info("weight: %s", example.weight)
logging.info("int_iden: %s", str(example.int_iden))
logging.info("example_id: %s", example.example_id)
feature = InputFeatures(
input_ids=input_ids,
......@@ -1036,11 +1281,86 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
label_id=label_id,
is_real_example=True,
weight=example.weight,
int_iden=example.int_iden)
example_id=example.example_id)
return feature
class AXgProcessor(DataProcessor):
"""Processor for the AXg dataset (SuperGLUE diagnostics dataset)."""
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "AX-g.jsonl")), "test")
def get_labels(self):
"""See base class."""
return ["entailment", "not_entailment"]
@staticmethod
def get_processor_name():
"""See base class."""
return "AXg"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
examples = []
for line in lines:
guid = "%s-%s" % (set_type, self.process_text_fn(str(line["idx"])))
text_a = self.process_text_fn(line["premise"])
text_b = self.process_text_fn(line["hypothesis"])
label = self.process_text_fn(line["label"])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class SuperGLUERTEProcessor(DataProcessor):
"""Processor for the RTE dataset (SuperGLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "train.jsonl")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "val.jsonl")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_jsonl(os.path.join(data_dir, "test.jsonl")), "test")
def get_labels(self):
"""See base class."""
# All datasets are converted to 2-class split, where for 3-class datasets we
# collapse neutral and contradiction into not_entailment.
return ["entailment", "not_entailment"]
@staticmethod
def get_processor_name():
"""See base class."""
return "RTESuperGLUE"
def _create_examples(self, lines, set_type):
"""Creates examples for the training/dev/test sets."""
examples = []
for i, line in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = self.process_text_fn(line["premise"])
text_b = self.process_text_fn(line["hypothesis"])
if set_type == "test":
label = "entailment"
else:
label = self.process_text_fn(line["label"])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
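# Illustrative sketch (added for clarity, not in the original file): a made-up
# SuperGLUE RTE "val.jsonl" record and the InputExample that the processor
# builds from it.
#
#   line = {"premise": "Dogs bark.", "hypothesis": "Animals make noise.",
#           "label": "entailment", "idx": 0}
#   # -> InputExample(guid="dev-0", text_a="Dogs bark.",
#   #                 text_b="Animals make noise.", label="entailment")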
def file_based_convert_examples_to_features(examples,
label_list,
max_seq_length,
......@@ -1052,7 +1372,7 @@ def file_based_convert_examples_to_features(examples,
tf.io.gfile.makedirs(os.path.dirname(output_file))
writer = tf.io.TFRecordWriter(output_file)
for (ex_index, example) in enumerate(examples):
for ex_index, example in enumerate(examples):
if ex_index % 10000 == 0:
logging.info("Writing example %d of %d", ex_index, len(examples))
......@@ -1079,8 +1399,10 @@ def file_based_convert_examples_to_features(examples,
[int(feature.is_real_example)])
if feature.weight is not None:
features["weight"] = create_float_feature([feature.weight])
if feature.int_iden is not None:
features["int_iden"] = create_int_feature([feature.int_iden])
if feature.example_id is not None:
features["example_id"] = create_int_feature([feature.example_id])
else:
features["example_id"] = create_int_feature([ex_index])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
......@@ -1113,7 +1435,7 @@ def generate_tf_record_from_data_file(processor,
max_seq_length=128):
"""Generates and saves training data into a tf record file.
Arguments:
Args:
processor: Input processor object to be used for generating data. Subclass
of `DataProcessor`.
data_dir: Directory that contains train/eval/test data to process.
......@@ -1137,13 +1459,15 @@ def generate_tf_record_from_data_file(processor,
label_type = getattr(processor, "label_type", None)
is_regression = getattr(processor, "is_regression", False)
has_sample_weights = getattr(processor, "weight_key", False)
assert train_data_output_path
train_input_data_examples = processor.get_train_examples(data_dir)
file_based_convert_examples_to_features(train_input_data_examples, label_list,
max_seq_length, tokenizer,
train_data_output_path, label_type)
num_training_data = len(train_input_data_examples)
num_training_data = 0
if train_data_output_path:
train_input_data_examples = processor.get_train_examples(data_dir)
file_based_convert_examples_to_features(train_input_data_examples,
label_list, max_seq_length,
tokenizer, train_data_output_path,
label_type)
num_training_data = len(train_input_data_examples)
if eval_data_output_path:
eval_input_data_examples = processor.get_dev_examples(data_dir)
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,17 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT finetuning task dataset generator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""BERT finetuning task dataset generator."""
import functools
import json
import os
# Import libraries
from absl import app
from absl import flags
import tensorflow as tf
......@@ -49,41 +46,60 @@ flags.DEFINE_string(
"The input data dir. Should contain the .tsv files (or other data files) "
"for the task.")
flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
"SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI",
"XTREME-PAWS-X"],
"The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"only and for XNLI is all languages combined. Same for "
"PAWS-X.")
# XNLI task specific flag.
flags.DEFINE_enum(
"classification_task_name", "MNLI", [
"AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
"SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X",
"AX-g", "SUPERGLUE-RTE"
], "The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"only and for XNLI is all languages combined. Same for "
"PAWS-X.")
# MNLI task-specific flag.
flags.DEFINE_enum("mnli_type", "matched", ["matched", "mismatched"],
"The type of MNLI dataset.")
# XNLI task-specific flag.
flags.DEFINE_string(
"xnli_language", "en",
"Language of training data for XNIL task. If the value is 'all', the data "
"Language of training data for XNLI task. If the value is 'all', the data "
"of all languages will be used for training.")
# PAWS-X task specific flag.
# PAWS-X task-specific flag.
flags.DEFINE_string(
"pawsx_language", "en",
"Language of trainig data for PAWS-X task. If the value is 'all', the data "
"Language of training data for PAWS-X task. If the value is 'all', the data "
"of all languages will be used for training.")
# Retrieva task specific flags
# XTREME classification specific flags. Only used in XtremePawsx and XtremeXnli.
flags.DEFINE_string(
"translated_input_data_dir", None,
"The translated input data dir. Should contain the .tsv files (or other "
"data files) for the task.")
# Retrieval task-specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
"The name of sentence retrieval task for scoring")
# Tagging task specific flags
# Tagging task-specific flags.
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
"The name of BERT tagging (token classification) task.")
# BERT Squad task specific flags.
flags.DEFINE_bool("tagging_only_use_en_train", True,
"Whether only use english training data in tagging.")
# BERT Squad task-specific flags.
flags.DEFINE_string(
"squad_data_file", None,
"The input data file in for generating training data for BERT squad task.")
flags.DEFINE_string(
"translated_squad_data_folder", None,
"The translated data folder for generating training data for BERT squad "
"task.")
flags.DEFINE_integer(
"doc_stride", 128,
"When splitting up a long document into chunks, how much stride to "
......@@ -98,6 +114,14 @@ flags.DEFINE_bool(
"version_2_with_negative", False,
"If true, the SQuAD examples contain some that do not have an answer.")
flags.DEFINE_bool(
"xlnet_format", False,
"If true, then data will be preprocessed in a paragraph, query, class order"
" instead of the BERT-style class, paragraph, query order.")
# XTREME specific flags.
flags.DEFINE_bool("only_use_en_dev", True, "Whether only use english dev data.")
# Shared flags across BERT fine-tuning tasks.
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
......@@ -136,36 +160,35 @@ flags.DEFINE_string("sp_model_file", "",
"The path to the model used by sentence piece tokenizer.")
flags.DEFINE_enum(
"tokenizer_impl", "word_piece", ["word_piece", "sentence_piece"],
"Specifies the tokenizer implementation, i.e., whehter to use word_piece "
"or sentence_piece tokenizer. Canonical BERT uses word_piece tokenizer, "
"while ALBERT uses sentence_piece tokenizer.")
"tokenization", "WordPiece", ["WordPiece", "SentencePiece"],
"Specifies the tokenizer implementation, i.e., whether to use WordPiece "
"or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
"while ALBERT uses SentencePiece tokenizer.")
flags.DEFINE_string("tfds_params", "",
"Comma-separated list of TFDS parameter assigments for "
"generic classfication data import (for more details "
"see the TfdsProcessor class documentation).")
flags.DEFINE_string(
"tfds_params", "", "Comma-separated list of TFDS parameter assigments for "
"generic classfication data import (for more details "
"see the TfdsProcessor class documentation).")
def generate_classifier_dataset():
"""Generates classifier dataset and returns input meta data."""
assert (FLAGS.input_data_dir and FLAGS.classification_task_name
or FLAGS.tfds_params)
assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
FLAGS.tfds_params)
if FLAGS.tokenizer_impl == "word_piece":
if FLAGS.tokenization == "WordPiece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
else:
assert FLAGS.tokenizer_impl == "sentence_piece"
assert FLAGS.tokenization == "SentencePiece"
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
if FLAGS.tfds_params:
processor = classifier_data_lib.TfdsProcessor(
tfds_params=FLAGS.tfds_params,
process_text_fn=processor_text_fn)
tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
return classifier_data_lib.generate_tf_record_from_data_file(
processor,
None,
......@@ -176,31 +199,51 @@ def generate_classifier_dataset():
max_seq_length=FLAGS.max_seq_length)
else:
processors = {
"ax":
classifier_data_lib.AxProcessor,
"cola":
classifier_data_lib.ColaProcessor,
"imdb":
classifier_data_lib.ImdbProcessor,
"mnli":
classifier_data_lib.MnliProcessor,
functools.partial(
classifier_data_lib.MnliProcessor, mnli_type=FLAGS.mnli_type),
"mrpc":
classifier_data_lib.MrpcProcessor,
"qnli":
classifier_data_lib.QnliProcessor,
"qqp": classifier_data_lib.QqpProcessor,
"rte": classifier_data_lib.RteProcessor,
"qqp":
classifier_data_lib.QqpProcessor,
"rte":
classifier_data_lib.RteProcessor,
"sst-2":
classifier_data_lib.SstProcessor,
"sts-b":
classifier_data_lib.StsBProcessor,
"xnli":
functools.partial(classifier_data_lib.XnliProcessor,
language=FLAGS.xnli_language),
functools.partial(
classifier_data_lib.XnliProcessor,
language=FLAGS.xnli_language),
"paws-x":
functools.partial(classifier_data_lib.PawsxProcessor,
language=FLAGS.pawsx_language),
"wnli": classifier_data_lib.WnliProcessor,
functools.partial(
classifier_data_lib.PawsxProcessor,
language=FLAGS.pawsx_language),
"wnli":
classifier_data_lib.WnliProcessor,
"xtreme-xnli":
functools.partial(classifier_data_lib.XtremeXnliProcessor),
functools.partial(
classifier_data_lib.XtremeXnliProcessor,
translated_data_dir=FLAGS.translated_input_data_dir,
only_use_en_dev=FLAGS.only_use_en_dev),
"xtreme-paws-x":
functools.partial(classifier_data_lib.XtremePawsxProcessor)
functools.partial(
classifier_data_lib.XtremePawsxProcessor,
translated_data_dir=FLAGS.translated_input_data_dir,
only_use_en_dev=FLAGS.only_use_en_dev),
"ax-g":
classifier_data_lib.AXgProcessor,
"superglue-rte":
classifier_data_lib.SuperGLUERTEProcessor
}
task_name = FLAGS.classification_task_name.lower()
if task_name not in processors:
......@@ -219,20 +262,19 @@ def generate_classifier_dataset():
def generate_regression_dataset():
"""Generates regression dataset and returns input meta data."""
if FLAGS.tokenizer_impl == "word_piece":
if FLAGS.tokenization == "WordPiece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
else:
assert FLAGS.tokenizer_impl == "sentence_piece"
assert FLAGS.tokenization == "SentencePiece"
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
if FLAGS.tfds_params:
processor = classifier_data_lib.TfdsProcessor(
tfds_params=FLAGS.tfds_params,
process_text_fn=processor_text_fn)
tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
return classifier_data_lib.generate_tf_record_from_data_file(
processor,
None,
......@@ -248,28 +290,42 @@ def generate_regression_dataset():
def generate_squad_dataset():
"""Generates squad training dataset and returns input meta data."""
assert FLAGS.squad_data_file
if FLAGS.tokenizer_impl == "word_piece":
if FLAGS.tokenization == "WordPiece":
return squad_lib_wp.generate_tf_record_from_json_file(
FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
FLAGS.doc_stride, FLAGS.version_2_with_negative)
input_file_path=FLAGS.squad_data_file,
vocab_file_path=FLAGS.vocab_file,
output_path=FLAGS.train_data_output_path,
translated_input_folder=FLAGS.translated_squad_data_folder,
max_seq_length=FLAGS.max_seq_length,
do_lower_case=FLAGS.do_lower_case,
max_query_length=FLAGS.max_query_length,
doc_stride=FLAGS.doc_stride,
version_2_with_negative=FLAGS.version_2_with_negative,
xlnet_format=FLAGS.xlnet_format)
else:
assert FLAGS.tokenizer_impl == "sentence_piece"
assert FLAGS.tokenization == "SentencePiece"
return squad_lib_sp.generate_tf_record_from_json_file(
FLAGS.squad_data_file, FLAGS.sp_model_file,
FLAGS.train_data_output_path, FLAGS.max_seq_length, FLAGS.do_lower_case,
FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative)
input_file_path=FLAGS.squad_data_file,
sp_model_file=FLAGS.sp_model_file,
output_path=FLAGS.train_data_output_path,
translated_input_folder=FLAGS.translated_squad_data_folder,
max_seq_length=FLAGS.max_seq_length,
do_lower_case=FLAGS.do_lower_case,
max_query_length=FLAGS.max_query_length,
doc_stride=FLAGS.doc_stride,
xlnet_format=FLAGS.xlnet_format,
version_2_with_negative=FLAGS.version_2_with_negative)
def generate_retrieval_dataset():
"""Generate retrieval test and dev dataset and returns input meta data."""
assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
if FLAGS.tokenizer_impl == "word_piece":
if FLAGS.tokenization == "WordPiece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
else:
assert FLAGS.tokenizer_impl == "sentence_piece"
assert FLAGS.tokenization == "SentencePiece"
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
......@@ -286,34 +342,38 @@ def generate_retrieval_dataset():
processor = processors[task_name](process_text_fn=processor_text_fn)
return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
processor,
FLAGS.input_data_dir,
tokenizer,
FLAGS.eval_data_output_path,
FLAGS.test_data_output_path,
FLAGS.max_seq_length)
processor, FLAGS.input_data_dir, tokenizer, FLAGS.eval_data_output_path,
FLAGS.test_data_output_path, FLAGS.max_seq_length)
def generate_tagging_dataset():
"""Generates tagging dataset."""
processors = {
"panx": tagging_data_lib.PanxProcessor,
"udpos": tagging_data_lib.UdposProcessor,
"panx":
functools.partial(
tagging_data_lib.PanxProcessor,
only_use_en_train=FLAGS.tagging_only_use_en_train,
only_use_en_dev=FLAGS.only_use_en_dev),
"udpos":
functools.partial(
tagging_data_lib.UdposProcessor,
only_use_en_train=FLAGS.tagging_only_use_en_train,
only_use_en_dev=FLAGS.only_use_en_dev),
}
task_name = FLAGS.tagging_task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % task_name)
if FLAGS.tokenizer_impl == "word_piece":
if FLAGS.tokenization == "WordPiece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
elif FLAGS.tokenizer_impl == "sentence_piece":
elif FLAGS.tokenization == "SentencePiece":
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
else:
raise ValueError("Unsupported tokenizer_impl: %s" % FLAGS.tokenizer_impl)
raise ValueError("Unsupported tokenization: %s" % FLAGS.tokenization)
processor = processors[task_name]()
return tagging_data_lib.generate_tf_record_from_data_file(
......@@ -323,12 +383,12 @@ def generate_tagging_dataset():
def main(_):
if FLAGS.tokenizer_impl == "word_piece":
if FLAGS.tokenization == "WordPiece":
if not FLAGS.vocab_file:
raise ValueError(
"FLAG vocab_file for word-piece tokenizer is not specified.")
else:
assert FLAGS.tokenizer_impl == "sentence_piece"
assert FLAGS.tokenization == "SentencePiece"
if not FLAGS.sp_model_file:
raise ValueError(
"FLAG sp_model_file for sentence-piece tokenizer is not specified.")
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,15 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create masked LM/next sentence masked_lm TF examples for BERT."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import itertools
import random
# Import libraries
from absl import app
from absl import flags
from absl import logging
......@@ -48,10 +47,20 @@ flags.DEFINE_bool(
"do_whole_word_mask", False,
"Whether to use whole word masking rather than per-WordPiece masking.")
flags.DEFINE_integer(
"max_ngram_size", None,
"Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
"weighting scheme to favor shorter n-grams. "
"Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
flags.DEFINE_bool(
"gzip_compress", False,
"Whether to use `GZIP` compress option to get compressed TFRecord files.")
flags.DEFINE_bool(
"use_v2_feature_names", False,
"Whether to use the feature names consistent with the models.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
flags.DEFINE_integer("max_predictions_per_seq", 20,
......@@ -101,8 +110,8 @@ class TrainingInstance(object):
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
max_predictions_per_seq, output_files,
gzip_compress):
"""Create TF example files from `TrainingInstance`s."""
gzip_compress, use_v2_feature_names):
"""Creates TF example files from `TrainingInstance`s."""
writers = []
for output_file in output_files:
writers.append(
......@@ -139,9 +148,14 @@ def write_instance_to_example_files(instances, tokenizer, max_seq_length,
next_sentence_label = 1 if instance.is_random_next else 0
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(input_ids)
if use_v2_feature_names:
features["input_word_ids"] = create_int_feature(input_ids)
features["input_type_ids"] = create_int_feature(segment_ids)
else:
features["input_ids"] = create_int_feature(input_ids)
features["segment_ids"] = create_int_feature(segment_ids)
features["input_mask"] = create_int_feature(input_mask)
features["segment_ids"] = create_int_feature(segment_ids)
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
......@@ -192,7 +206,8 @@ def create_training_instances(input_files,
masked_lm_prob,
max_predictions_per_seq,
rng,
do_whole_word_mask=False):
do_whole_word_mask=False,
max_ngram_size=None):
"""Create `TrainingInstance`s from raw text."""
all_documents = [[]]
......@@ -229,7 +244,7 @@ def create_training_instances(input_files,
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask))
do_whole_word_mask, max_ngram_size))
rng.shuffle(instances)
return instances
......@@ -238,7 +253,8 @@ def create_training_instances(input_files,
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask=False):
do_whole_word_mask=False,
max_ngram_size=None):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index]
......@@ -337,7 +353,7 @@ def create_instances_from_document(
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask)
do_whole_word_mask, max_ngram_size)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
......@@ -355,72 +371,238 @@ def create_instances_from_document(
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"])
# A _Gram is a [half-open) interval of token indices which form a word.
# E.g.,
# words: ["The", "doghouse"]
# tokens: ["The", "dog", "##house"]
# grams: [(0,1), (1,3)]
_Gram = collections.namedtuple("_Gram", ["begin", "end"])
def _window(iterable, size):
"""Helper to create a sliding window iterator with a given size.
E.g.,
input = [1, 2, 3, 4]
_window(input, 1) => [1], [2], [3], [4]
_window(input, 2) => [1, 2], [2, 3], [3, 4]
_window(input, 3) => [1, 2, 3], [2, 3, 4]
_window(input, 4) => [1, 2, 3, 4]
_window(input, 5) => None
Args:
iterable: elements to iterate over.
size: size of the window.
Yields:
Elements of `iterable` batched into a sliding window of length `size`.
"""
i = iter(iterable)
window = []
try:
for e in range(0, size):
window.append(next(i))
yield window
except StopIteration:
# handle the case where iterable's length is less than the window size.
return
for e in i:
window = window[1:] + [e]
yield window
def _contiguous(sorted_grams):
"""Test whether a sequence of grams is contiguous.
Args:
sorted_grams: _Grams which are sorted in increasing order.
Returns:
True if `sorted_grams` are touching each other.
E.g.,
_contiguous([(1, 4), (4, 5), (5, 10)]) == True
_contiguous([(1, 2), (4, 5)]) == False
"""
for a, b in _window(sorted_grams, 2):
if a.end != b.begin:
return False
return True
def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
"""Create a list of masking {1, ..., n}-grams from a list of one-grams.
  This is an extension of 'whole word masking' to mask multiple, contiguous
  words (e.g., "the red boat").
Each input gram represents the token indices of a single word,
words: ["the", "red", "boat"]
tokens: ["the", "red", "boa", "##t"]
grams: [(0,1), (1,2), (2,4)]
  For a `max_ngram_size` of three, possible output masks include:
    1-grams: (0,1), (1,2), (2,4)
    2-grams: (0,2), (1,4)
    3-grams: (0,4)
  Output masks will not overlap and will contain no more than
  `max_masked_tokens` total tokens. E.g., for the example above with
  `max_masked_tokens` as three, valid outputs are:
[(0,1), (1,2)] # "the", "red" covering two tokens
[(1,2), (2,4)] # "red", "boa", "##t" covering three tokens
  The length of the selected n-gram follows a Zipf weighting to
favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
Args:
grams: List of one-grams.
max_ngram_size: Maximum number of contiguous one-grams combined to create
an n-gram.
max_masked_tokens: Maximum total number of tokens to be masked.
rng: `random.Random` generator.
Returns:
A list of n-grams to be used as masks.
"""
if not grams:
return None
grams = sorted(grams)
num_tokens = grams[-1].end
# Ensure our grams are valid (i.e., they don't overlap).
for a, b in _window(grams, 2):
if a.end > b.begin:
raise ValueError("overlapping grams: {}".format(grams))
# Build map from n-gram length to list of n-grams.
ngrams = {i: [] for i in range(1, max_ngram_size+1)}
for gram_size in range(1, max_ngram_size+1):
for g in _window(grams, gram_size):
if _contiguous(g):
# Add an n-gram which spans these one-grams.
ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
# Shuffle each list of n-grams.
for v in ngrams.values():
rng.shuffle(v)
# Create the weighting for n-gram length selection.
  # Stored cumulatively for `random.choices` below.
cummulative_weights = list(
itertools.accumulate([1./n for n in range(1, max_ngram_size+1)]))
output_ngrams = []
# Keep a bitmask of which tokens have been masked.
masked_tokens = [False] * num_tokens
# Loop until we have enough masked tokens or there are no more candidate
# n-grams of any length.
# Each code path should ensure one or more elements from `ngrams` are removed
  # to guarantee this loop terminates.
while (sum(masked_tokens) < max_masked_tokens and
sum(len(s) for s in ngrams.values())):
# Pick an n-gram size based on our weights.
sz = random.choices(range(1, max_ngram_size+1),
cum_weights=cummulative_weights)[0]
# Ensure this size doesn't result in too many masked tokens.
# E.g., a two-gram contains _at least_ two tokens.
if sum(masked_tokens) + sz > max_masked_tokens:
# All n-grams of this length are too long and can be removed from
# consideration.
ngrams[sz].clear()
continue
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask):
"""Creates the predictions for the masked LM objective."""
# All of the n-grams of this size have been used.
if not ngrams[sz]:
continue
# Choose a random n-gram of the given size.
gram = ngrams[sz].pop()
    num_gram_tokens = gram.end - gram.begin
# Check if this would add too many tokens.
if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
continue
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
# Check if any of the tokens in this gram have already been masked.
if sum(masked_tokens[gram.begin:gram.end]):
continue
# Whole Word Masking means that if we mask all of the wordpieces
# corresponding to an original word. When a word has been split into
# WordPieces, the first token does not have any marker and any subsequence
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
# Found a usable n-gram! Mark its tokens as masked and add it to return.
masked_tokens[gram.begin:gram.end] = [True] * (gram.end-gram.begin)
output_ngrams.append(gram)
return output_ngrams
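# Illustrative sketch (added for clarity, not in the original file): masking
# the "the red boat" grams from the docstring above with at most three masked
# tokens. The exact spans returned depend on `rng`.
#
#   masks = _masking_ngrams(
#       grams=[_Gram(0, 1), _Gram(1, 2), _Gram(2, 4)],
#       max_ngram_size=3, max_masked_tokens=3, rng=random.Random(0))
#   # `masks` is a list of non-overlapping _Gram spans covering at most three
#   # tokens in total, e.g. [_Gram(begin=1, end=2), _Gram(begin=2, end=4)].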
def _wordpieces_to_grams(tokens):
"""Reconstitue grams (words) from `tokens`.
E.g.,
tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
    grams: [[1, 2), [2, 4), [4, 5), [5, 6)]
Args:
tokens: list of wordpieces
Returns:
List of _Grams representing spans of whole words
(without "[CLS]" and "[SEP]").
"""
grams = []
gram_start_pos = None
for i, token in enumerate(tokens):
if gram_start_pos is not None and token.startswith("##"):
continue
if gram_start_pos is not None:
grams.append(_Gram(gram_start_pos, i))
if token not in ["[CLS]", "[SEP]"]:
gram_start_pos = i
else:
cand_indexes.append([i])
gram_start_pos = None
if gram_start_pos is not None:
grams.append(_Gram(gram_start_pos, len(tokens)))
return grams
rng.shuffle(cand_indexes)
output_tokens = list(tokens)
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask,
max_ngram_size=None):
"""Creates the predictions for the masked LM objective."""
if do_whole_word_mask:
grams = _wordpieces_to_grams(tokens)
else:
# Here we consider each token to be a word to allow for sub-word masking.
if max_ngram_size:
raise ValueError("cannot use ngram masking without whole word masking")
grams = [_Gram(i, i+1) for i in range(0, len(tokens))
if tokens[i] not in ["[CLS]", "[SEP]"]]
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
  # Generate masks. If `max_ngram_size` is 0 or None, we are doing whole word
  # masking or token-level masking; both can be treated as the
  # `max_ngram_size=1` case.
masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
num_to_predict, rng)
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if rng.random() < 0.8:
masked_token = "[MASK]"
output_tokens = list(tokens)
for gram in masked_grams:
# 80% of the time, replace all n-gram tokens with [MASK]
if rng.random() < 0.8:
replacement_action = lambda idx: "[MASK]"
else:
# 10% of the time, keep all the original n-gram tokens.
if rng.random() < 0.5:
replacement_action = lambda idx: tokens[idx]
# 10% of the time, replace each n-gram token with a random word.
else:
# 10% of the time, keep original
if rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
replacement_action = lambda idx: rng.choice(vocab_words)
output_tokens[index] = masked_token
for idx in range(gram.begin, gram.end):
output_tokens[idx] = replacement_action(idx)
masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index)
......@@ -467,7 +649,7 @@ def main(_):
instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng, FLAGS.do_whole_word_mask)
rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)
output_files = FLAGS.output_file.split(",")
logging.info("*** Writing to output files ***")
......@@ -476,7 +658,8 @@ def main(_):
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq, output_files,
FLAGS.gzip_compress)
FLAGS.gzip_compress,
FLAGS.use_v2_feature_names)
if __name__ == "__main__":
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.data.create_pretraining_data."""
import random
import tensorflow as tf
from official.nlp.data import create_pretraining_data as cpd
_VOCAB_WORDS = ["vocab_1", "vocab_2"]
class CreatePretrainingDataTest(tf.test.TestCase):
def assertTokens(self, input_tokens, output_tokens, masked_positions,
masked_labels):
# Ensure the masked positions are unique.
self.assertCountEqual(masked_positions, set(masked_positions))
# Ensure we can reconstruct the input from the output.
reconstructed_tokens = output_tokens
for pos, label in zip(masked_positions, masked_labels):
reconstructed_tokens[pos] = label
self.assertEqual(input_tokens, reconstructed_tokens)
# Ensure each label is valid.
for pos, label in zip(masked_positions, masked_labels):
output_token = output_tokens[pos]
if (output_token == "[MASK]" or output_token in _VOCAB_WORDS or
output_token == input_tokens[pos]):
continue
self.fail("invalid mask value: {}".format(output_token))
def test_wordpieces_to_grams(self):
tests = [
(["That", "cone"], [(0, 1), (1, 2)]),
(["That", "cone", "##s"], [(0, 1), (1, 3)]),
(["Swit", "##zer", "##land"], [(0, 3)]),
(["[CLS]", "Up", "##dog"], [(1, 3)]),
(["[CLS]", "Up", "##dog", "[SEP]", "Down"], [(1, 3), (4, 5)]),
]
for inp, expected in tests:
output = cpd._wordpieces_to_grams(inp)
self.assertEqual(expected, output)
def test_window(self):
input_list = [1, 2, 3, 4]
window_outputs = [
(1, [[1], [2], [3], [4]]),
(2, [[1, 2], [2, 3], [3, 4]]),
(3, [[1, 2, 3], [2, 3, 4]]),
(4, [[1, 2, 3, 4]]),
(5, []),
]
for window, expected in window_outputs:
output = cpd._window(input_list, window)
self.assertEqual(expected, list(output))
def test_create_masked_lm_predictions(self):
tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
rng = random.Random(123)
for _ in range(0, 5):
output_tokens, masked_positions, masked_labels = (
cpd.create_masked_lm_predictions(
tokens=tokens,
masked_lm_prob=1.0,
max_predictions_per_seq=3,
vocab_words=_VOCAB_WORDS,
rng=rng,
do_whole_word_mask=False,
max_ngram_size=None))
self.assertEqual(len(masked_positions), 3)
self.assertEqual(len(masked_labels), 3)
self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
def test_create_masked_lm_predictions_whole_word(self):
tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
rng = random.Random(345)
for _ in range(0, 5):
output_tokens, masked_positions, masked_labels = (
cpd.create_masked_lm_predictions(
tokens=tokens,
masked_lm_prob=1.0,
max_predictions_per_seq=3,
vocab_words=_VOCAB_WORDS,
rng=rng,
do_whole_word_mask=True,
max_ngram_size=None))
      # Since we can't get exactly three tokens without breaking a word, we
      # only take two.
self.assertEqual(len(masked_positions), 2)
self.assertEqual(len(masked_labels), 2)
self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
# ensure that we took an entire word.
self.assertIn(masked_labels, [["a", "##a"], ["b", "##b"], ["c", "##c"]])
def test_create_masked_lm_predictions_ngram(self):
tokens = ["[CLS]"] + ["tok{}".format(i) for i in range(0, 512)] + ["[SEP]"]
rng = random.Random(345)
for _ in range(0, 5):
output_tokens, masked_positions, masked_labels = (
cpd.create_masked_lm_predictions(
tokens=tokens,
masked_lm_prob=1.0,
max_predictions_per_seq=76,
vocab_words=_VOCAB_WORDS,
rng=rng,
do_whole_word_mask=True,
max_ngram_size=3))
self.assertEqual(len(masked_positions), 76)
self.assertEqual(len(masked_labels), 76)
self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Create LM TF examples for XLNet."""
import json
import math
import os
import random
from typing import Iterable, Mapping, List, Optional, Tuple
import unicodedata
# Import libraries
from absl import app
from absl import flags
from absl import logging
import dataclasses
import numpy as np
import tensorflow as tf
from official.nlp.bert import tokenization
special_symbols = {
"<unk>": 0,
"<s>": 1,
"</s>": 2,
"<cls>": 3,
"<sep>": 4,
"<pad>": 5,
"<mask>": 6,
"<eod>": 7,
"<eop>": 8,
}
FLAGS = flags.FLAGS
flags.DEFINE_integer("seq_length", 512,
help="Sequence length.")
flags.DEFINE_integer("reuse_length", 256,
help="Number of token that can be reused as memory. "
"Could be half of `seq_len`.")
flags.DEFINE_string("input_file", None,
"Input raw text file (or comma-separated list of files).")
flags.DEFINE_string(
"save_dir", None,
"Directory for saving processed data.")
flags.DEFINE_string("sp_model_file", "",
"The path to the model used by sentence piece tokenizer.")
flags.DEFINE_bool("use_eod_token", True,
"Whether or not to include EOD tokens.")
flags.DEFINE_bool("bi_data", True, "Whether or not to use bi-directional data.")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
flags.DEFINE_integer("per_host_batch_size", 32, "Batch size per host.")
flags.DEFINE_integer("num_cores_per_host", 16,
"The number of (TPU) cores per host.")
flags.DEFINE_string("prefix", "", "Filename prefix.")
flags.DEFINE_string("suffix", "", "Filename suffix.")
flags.DEFINE_integer("task_id", None,
"The id of the current task.")
flags.DEFINE_integer("num_tasks", None,
"The total number of tasks.")
flags.DEFINE_integer("num_passes", 1, "The number of times to run the script.")
@dataclasses.dataclass
class TrainingInstance:
"""Representation of a single XLNet Pretraining instance."""
data: Iterable[int]
segment_ids: Iterable[int]
boundary_indices: Iterable[int]
label: int
def to_feature(self) -> Mapping[str, tf.train.Feature]:
feat = lambda x: tf.train.Feature(int64_list=tf.train.Int64List(value=x))
return dict(
input_word_ids=feat(self.data),
input_type_ids=feat(self.segment_ids),
boundary_indices=feat(self.boundary_indices),
label=feat([self.label]))
def to_example(self) -> tf.train.Example:
return tf.train.Example(
features=tf.train.Features(feature=self.to_feature()))
def __str__(self):
def seq_to_str(seq):
return " ".join([str(x) for x in seq])
s = ""
s += "tokens: %s\n" % seq_to_str(self.data)
s += "segment_ids: %s\n" % seq_to_str(self.segment_ids)
s += "boundary_indices: %s\n" % seq_to_str(self.boundary_indices)
s += "label: %s\n" % self.label
s += "\n"
return s
def __repr__(self):
return self.__str__()
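# Illustrative sketch (added for clarity, not in the original file): building
# a tiny TrainingInstance and serializing it; the id values are made up.
#
#   instance = TrainingInstance(
#       data=[4, 17, 35, 3], segment_ids=[0, 0, 1, 1],
#       boundary_indices=[0, 2, 4], label=1)
#   serialized = instance.to_example().SerializeToString()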
def _preprocess_line(line: str, do_lower_case: bool = False) -> str:
"""Preprocesses an individual raw text line.
This function will:
- Remove extraneous spaces.
- Replace `` with ", and '' with ".
  - Remove accents.
  - Apply lower casing.
Args:
line: The input line to preprocess.
do_lower_case: Whether or not to lower case the text.
Returns:
The preprocessed line.
"""
line = " ".join(line.split())
line = line.replace("``", "\"").replace("''", "\"")
  # Remove accents.
line = unicodedata.normalize("NFKD", line)
line = "".join([c for c in line if not unicodedata.combining(c)])
if do_lower_case:
line = line.lower()
return line
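# Illustrative sketch (added for clarity, not in the original file): what the
# helper above does to a raw line; the sample text is made up.
#
#   _preprocess_line("``Crème   brûlée'' is GREAT", do_lower_case=True)
#   # -> '"creme brulee" is great'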
def preprocess_and_tokenize_input_files(
input_files: Iterable[str],
tokenizer: tokenization.FullSentencePieceTokenizer,
use_eod: bool = True,
do_lower_case: bool = False,
log_example_freq: int = 100000) -> List[Tuple[np.array, np.array]]:
"""Preprocesses and encodes raw text from input files.
  This function preprocesses raw text and encodes it into tokens using a
`SentencePieceModel` tokenization method. This also provides the sentence
indicator for each token.
Args:
input_files: The list of input file names.
tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
use_eod: Whether or not to use an EOD indicator. If `False`, then EOD is
not included.
do_lower_case: Whether or not to apply lower casing during raw text
preprocessing.
log_example_freq: The optional field for how many lines to process before
emitting an info log.
Returns:
The preprocessed list. Each entry in the list is a tuple consisting of
the token IDs and the sentence IDs.
"""
all_data = []
eod_symbol = special_symbols["<eod>"]
total_number_of_lines = 0
# Input file format:
# (1) One sentence per line. These should ideally be actual sentences, not
# entire paragraphs or arbitrary spans of text. (Because we use the
# sentence boundaries for the "next sentence prediction" task).
# (2) Blank lines between documents. Document boundaries are needed so
# that the "next sentence prediction" task doesn't span between documents.
for input_file in input_files:
line_count = 0
logging.info("Preprocessing %s", input_file)
all_tokens = []
all_sentence_ids = []
sentence_id = True
with tf.io.gfile.GFile(input_file, "rb") as reader:
while True:
line = tokenization.convert_to_unicode(reader.readline())
if not line:
break
line_count += 1
if line_count % log_example_freq == 0:
logging.info("Loading line %d", line_count)
line = line.strip()
if not line:
if use_eod:
token_ids = [eod_symbol]
sentence_id = not sentence_id
else:
continue
else:
preprocessed_line = _preprocess_line(
line=line, do_lower_case=do_lower_case)
token_ids = tokenization.encode_ids(
sp_model=tokenizer.sp_model, text=preprocessed_line)
all_tokens.extend(token_ids)
all_sentence_ids.extend([sentence_id] * len(token_ids))
sentence_id = not sentence_id
logging.info("Finished processing %s. Number of lines: %d",
input_file, line_count)
if line_count == 0:
continue
total_number_of_lines += line_count
all_tokens = np.array(all_tokens, dtype=np.int64)
all_sentence_ids = np.array(all_sentence_ids, dtype=np.bool)
all_data.append((all_tokens, all_sentence_ids))
logging.info("Completed text preprocessing. Total number of lines: %d",
total_number_of_lines)
return all_data
def _reshape_to_batch_dimensions(
tokens: np.array,
sentence_ids: np.array,
per_host_batch_size: int) -> Tuple[np.array, np.array]:
"""Truncates and reshapes input data with a batch major dimension.
Args:
tokens: The input token ids. This should have the same shape as
`sentence_ids`.
sentence_ids: The input sentence ids. This should have the same shape as
`token_ids`.
per_host_batch_size: The target per-host batch size.
Returns:
The tuple of reshaped tokens and sentence_ids.
"""
num_steps = len(tokens) // per_host_batch_size
truncated_data_length = num_steps * per_host_batch_size
logging.info("per_host_batch_size: %d", per_host_batch_size)
logging.info("num_steps: %d", num_steps)
def truncate_and_reshape(a):
return a[:truncated_data_length].reshape((per_host_batch_size, num_steps))
return (truncate_and_reshape(tokens), truncate_and_reshape(sentence_ids))
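# Illustrative sketch (added for clarity, not in the original file): a stream
# of 10 token ids reshaped to a per-host batch of 2 yields arrays of shape
# (2, 5); any remainder beyond num_steps * per_host_batch_size is truncated.
#
#   tokens, sentence_ids = _reshape_to_batch_dimensions(
#       tokens=np.arange(10), sentence_ids=np.zeros(10, dtype=bool),
#       per_host_batch_size=2)
#   # tokens.shape == (2, 5)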
def _create_a_and_b_segments(
tokens: np.array,
sentence_ids: np.array,
begin_index: int,
total_length: int,
no_cut_probability: float = 0.5):
"""Splits segments A and B from a single instance of tokens and sentence ids.
Args:
tokens: The 1D input token ids. This represents an individual entry within a
batch.
    sentence_ids: The 1D input sentence ids. This represents an individual entry
within a batch. This should be the same length as `tokens`.
begin_index: The reference beginning index to split data.
total_length: The target combined length of segments A and B.
no_cut_probability: The probability of not cutting a segment despite
a cut possibly existing.
Returns:
A tuple consisting of A data, B data, and label.
"""
data_length = tokens.shape[0]
if begin_index + total_length >= data_length:
logging.info("[_create_segments]: begin_index %d + total_length %d >= "
"data_length %d", begin_index, total_length, data_length)
return None
end_index = begin_index + 1
cut_indices = []
# Identify all indices where sentence IDs change from one to the next.
while end_index < data_length:
if sentence_ids[end_index] != sentence_ids[end_index - 1]:
if end_index - begin_index >= total_length:
break
cut_indices.append(end_index)
end_index += 1
a_begin = begin_index
if not cut_indices or random.random() < no_cut_probability:
# Segments A and B are contained within the same sentence.
label = 0
if not cut_indices:
a_end = end_index
else:
a_end = random.choice(cut_indices)
b_length = max(1, total_length - (a_end - a_begin))
b_begin = random.randint(0, data_length - 1 - b_length)
b_end = b_begin + b_length
while b_begin > 0 and sentence_ids[b_begin - 1] == sentence_ids[b_begin]:
b_begin -= 1
while (b_end < data_length - 1 and
sentence_ids[b_end - 1] == sentence_ids[b_end]):
b_end += 1
else:
# Segments A and B are different sentences.
label = 1
a_end = random.choice(cut_indices)
b_begin = a_end
b_end = end_index
while a_end - a_begin + b_end - b_begin > total_length:
if a_end - a_begin > b_end - b_begin:
# Delete only the right side for the LM objective.
a_end -= 1
else:
b_end -= 1
if a_end >= data_length or b_end >= data_length:
logging.info("[_create_segments]: a_end %d or b_end %d >= data_length %d",
a_end, b_end, data_length)
return None
a_data = tokens[a_begin: a_end]
b_data = tokens[b_begin: b_end]
return a_data, b_data, label
def _is_functional_piece(piece: str) -> bool:
return piece != "<unk>" and piece.startswith("<") and piece.endswith(">")
def _is_start_piece(piece: str) -> bool:
special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
if (piece.startswith("▁") or piece in special_pieces):
return True
else:
return False
def _get_boundary_indices(
data: np.array,
tokenizer: tokenization.FullSentencePieceTokenizer) -> np.array:
"""Gets the boundary indices of whole words."""
seq_length = len(data)
boundary_indices = []
for index, piece in enumerate(tokenizer.convert_ids_to_tokens(data.tolist())):
if _is_start_piece(piece) and not _is_functional_piece(piece):
boundary_indices.append(index)
boundary_indices.append(seq_length)
return boundary_indices
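# Illustrative example (the piece strings are hypothetical): if the ids in
# `data` decode to ["▁the", "▁quick", "est", "▁fox", "<sep>"], then "▁the",
# "▁quick", and "▁fox" begin whole words, "est" is a continuation piece, and
# "<sep>" is a functional piece, so _get_boundary_indices returns
# [0, 1, 3, 5], where the final entry is the sequence length.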
def _convert_tokens_to_instances(
tokens: np.array,
sentence_ids: np.array,
per_host_batch_size: int,
seq_length: int,
reuse_length: int,
bi_data: bool,
tokenizer: tokenization.FullSentencePieceTokenizer,
num_cores_per_host: int = 0,
logging_frequency: int = 500) -> List[TrainingInstance]:
"""Converts tokens and sentence IDs into individual training instances.
The format of data in the XLNet pretraining task is very similar to the
BERT pretraining task. Two segments A and B are randomly sampled, and the
  concatenation of A and B into a single sequence is used to perform
language modeling.
To create an XLNet Pretraining instance from a single long sequence, S:
- Create a segment of length `reuse_length`. This first segment represents
    past tokens. During modeling, this segment is used to cache content
    representations for the segment-recurrence (memory) mechanism.
- Similar to BERT, create a segment of length `seq_length` - `reuse_length`
composed of A and B segments.
For XLNet, the order is "A", "SEP", "B", "SEP", "CLS".
Args:
tokens: All tokens concatenated into a single list.
sentence_ids: All sentence IDs concatenated into a single list.
per_host_batch_size: The target batch size per host.
seq_length: The max sequence length.
reuse_length: The number of tokens to use from the previous segment.
bi_data: Whether or not to use bidirectional data.
tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
    num_cores_per_host: The number of cores per host. This is required when
      `bi_data` is `True`.
logging_frequency: The frequency at which to log status updates.
Returns:
A list of `TrainingInstance` objects.
"""
instances = []
per_core_batch_size = (per_host_batch_size // num_cores_per_host
if bi_data else None)
if bi_data:
logging.info("Bi-directional data enabled.")
assert per_host_batch_size % (2 * num_cores_per_host) == 0
forward_tokens, forward_sentence_ids = _reshape_to_batch_dimensions(
tokens=tokens,
sentence_ids=sentence_ids,
per_host_batch_size=per_host_batch_size // 2)
forward_data_shape = (num_cores_per_host, 1, per_core_batch_size // 2, -1)
forward_tokens = forward_tokens.reshape(forward_data_shape)
forward_sentence_ids = forward_sentence_ids.reshape(forward_data_shape)
backwards_tokens = forward_tokens[:, :, :, ::-1]
backwards_sentence_ids = forward_sentence_ids[:, :, :, ::-1]
tokens = np.concatenate([forward_tokens, backwards_tokens], 1).reshape(
per_host_batch_size, -1)
    sentence_ids = np.concatenate(
        [forward_sentence_ids, backwards_sentence_ids], 1).reshape(
            per_host_batch_size, -1)
else:
logging.info("Bi-directional data disabled.")
tokens, sentence_ids = _reshape_to_batch_dimensions(
tokens=tokens,
sentence_ids=sentence_ids,
per_host_batch_size=per_host_batch_size)
logging.info("Tokens shape: %s", tokens.shape)
data_length = tokens.shape[1]
sep = np.array([special_symbols["<sep>"]], dtype=np.int64)
cls = np.array([special_symbols["<cls>"]], dtype=np.int64)
# 2 sep, 1 cls
num_special_tokens = 3
data_index = 0
batch_number = 0
step_size = reuse_length if reuse_length else seq_length
num_batches = math.ceil(data_length / step_size)
while data_index + seq_length <= data_length:
if batch_number % logging_frequency == 0:
logging.info("Processing batch %d of %d", batch_number, num_batches)
for batch_index in range(per_host_batch_size):
previous_segment_tokens = tokens[
batch_index, data_index: data_index + reuse_length]
results = _create_a_and_b_segments(
tokens=tokens[batch_index],
sentence_ids=sentence_ids[batch_index],
begin_index=data_index + reuse_length,
total_length=seq_length - reuse_length - num_special_tokens)
if results is None:
logging.info("Stopping at data index: %d", data_index)
break
a_data, b_data, label = results
data = np.concatenate(
[previous_segment_tokens, a_data, sep, b_data, sep, cls])
a_length = a_data.shape[0]
b_length = b_data.shape[0]
segment_ids = ([0] * (reuse_length + a_length) + [0]
+ [1] * b_length + [1] + [2])
boundary_indices = _get_boundary_indices(tokenizer=tokenizer,
data=data)
assert len(data) == seq_length
assert len(segment_ids) == seq_length
assert len(boundary_indices) > 0 # pylint: disable=g-explicit-length-test
instances.append(TrainingInstance(
data=data,
segment_ids=segment_ids,
boundary_indices=boundary_indices,
label=label))
batch_number += 1
data_index += step_size
return instances
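# Illustrative layout of one instance produced above (lengths are
# hypothetical): with reuse_length=4, a_length=3, b_length=2 and seq_length=12,
#   data        = [m0, m1, m2, m3, a0, a1, a2, <sep>, b0, b1, <sep>, <cls>]
#   segment_ids = [ 0,  0,  0,  0,  0,  0,  0,     0,  1,  1,     1,     2]
# where m* are the cached previous-segment tokens used by the recurrence
# mechanism.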
def write_instances_to_tfrecord(
instances: Iterable[TrainingInstance],
save_path: str):
"""Writes instances to TFRecord."""
record_writer = tf.io.TFRecordWriter(save_path)
logging.info("Start writing to %s.", save_path)
for i, instance in enumerate(instances):
if i < 5:
logging.info("Instance %d: %s", i, str(instance))
record_writer.write(instance.to_example().SerializeToString())
record_writer.close()
logging.info("Done writing %s.", save_path)
def shuffle_and_combine_preprocessed_data(
all_data: List[Tuple[np.array, np.array]]) -> Tuple[np.array, np.array]:
"""Shuffles and combines preprocessed token/sentence IDs from documents."""
document_permutation = np.random.permutation(len(all_data))
previous_sentence_id = None
all_tokens, all_sentence_ids = [], []
for document_index in document_permutation:
tokens, sentence_ids = all_data[document_index]
# pylint: disable=g-explicit-length-test
if len(tokens) == 0:
continue
if (previous_sentence_id is not None and
sentence_ids[0] == previous_sentence_id):
sentence_ids = np.logical_not(sentence_ids)
all_tokens.append(tokens)
all_sentence_ids.append(sentence_ids)
previous_sentence_id = sentence_ids[-1]
return np.concatenate(all_tokens), np.concatenate(all_sentence_ids)
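# Illustrative example: if the previously appended document ends with
# sentence_ids [..., 0, 1] and the next shuffled document starts with
# [1, 1, 0, ...], the second document's ids are flipped via logical_not to
# [0, 0, 1, ...] so that consecutive sentences never share an id across the
# document boundary.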
def get_tfrecord_name(
per_host_batch_size: int,
num_cores_per_host: int,
seq_length: int,
bi_data: bool,
reuse_length: int,
do_lower_case: bool,
use_eod_token: bool,
prefix: str = "",
suffix: str = "",
pass_id: int = 0,
num_passes: int = 1,
    task_id: Optional[int] = None,
    num_tasks: Optional[int] = None) -> str:
"""Formats the resulting TFRecord name based on provided inputs."""
components = []
if prefix:
components.append(prefix)
components.append("seqlen-{}".format(seq_length))
if reuse_length == 0:
components.append("memless")
else:
components.append("reuse-{}".format(reuse_length))
components.append("bs-{}".format(per_host_batch_size))
components.append("cores-{}".format(num_cores_per_host))
if do_lower_case:
components.append("uncased")
else:
components.append("cased")
if use_eod_token:
components.append("eod")
if bi_data:
components.append("bi")
else:
components.append("uni")
if suffix:
components.append(suffix)
s = "_".join(components) + ".tfrecord"
if num_passes == 1 and task_id is None:
return s
if task_id is None:
num_tasks = 1
task_id = 0
current_shard = task_id * num_passes + pass_id
total_shards = num_tasks * num_passes
return s + "-{}-of-{}".format(current_shard, total_shards)
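# Illustrative example (argument values are hypothetical): seq_length=512,
# reuse_length=256, per_host_batch_size=32, num_cores_per_host=8,
# do_lower_case=True, use_eod_token=True, bi_data=True, a single pass, and no
# prefix/suffix yield:
#   "seqlen-512_reuse-256_bs-32_cores-8_uncased_eod_bi.tfrecord"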
def create_tfrecords(
tokenizer: tokenization.FullSentencePieceTokenizer,
input_file_or_files: str,
use_eod_token: bool,
do_lower_case: bool,
per_host_batch_size: int,
seq_length: int,
reuse_length: int,
bi_data: bool,
num_cores_per_host: int,
save_dir: str,
prefix: str = "",
suffix: str = "",
num_tasks: Optional[int] = None,
task_id: Optional[int] = None,
num_passes: int = 1):
"""Runs the end-to-end preprocessing pipeline."""
logging.info("Input configuration:")
logging.info("input file(s): %s", input_file_or_files)
logging.info("use_eod_token: %s", use_eod_token)
logging.info("do_lower_case: %s", do_lower_case)
logging.info("per_host_batch_size: %d", per_host_batch_size)
logging.info("seq_length: %d", seq_length)
logging.info("reuse_length: %d", reuse_length)
logging.info("bi_data: %s", bi_data)
logging.info("num_cores_per_host: %d", num_cores_per_host)
logging.info("save_dir: %s", save_dir)
if task_id is not None and num_tasks is not None:
logging.info("task_id: %d", task_id)
logging.info("num_tasks: %d", num_tasks)
input_files = []
for input_pattern in input_file_or_files.split(","):
input_files.extend(tf.io.gfile.glob(input_pattern))
logging.info("*** Reading from input files ***")
for input_file in input_files:
logging.info(" %s", input_file)
logging.info("Shuffling the files with a fixed random seed.")
np.random.shuffle(input_files)
if num_tasks is not None:
assert task_id is not None
logging.info("Total number of input files: %d", len(input_files))
logging.info("Splitting into %d shards of %d files each.",
num_tasks, len(input_files) // num_tasks)
input_files = input_files[task_id::num_tasks]
all_data = preprocess_and_tokenize_input_files(
input_files=input_files,
tokenizer=tokenizer,
use_eod=use_eod_token,
do_lower_case=do_lower_case)
for pass_id in range(num_passes):
logging.info("Beginning pass %d of %d", pass_id, num_passes)
tokens, sentence_ids = shuffle_and_combine_preprocessed_data(all_data)
assert len(tokens) == len(sentence_ids)
filename = get_tfrecord_name(
per_host_batch_size=per_host_batch_size,
num_cores_per_host=num_cores_per_host,
seq_length=seq_length,
bi_data=bi_data,
use_eod_token=use_eod_token,
reuse_length=reuse_length,
do_lower_case=do_lower_case,
prefix=prefix,
suffix=suffix,
pass_id=pass_id,
num_passes=num_passes,
num_tasks=num_tasks,
task_id=task_id)
save_path = os.path.join(save_dir, filename)
if os.path.exists(save_path):
# If the path already exists, then we were probably preempted but
# previously wrote this file.
logging.info("%s already exists, skipping this batch.", save_path)
else:
instances = _convert_tokens_to_instances(
tokenizer=tokenizer,
tokens=tokens,
sentence_ids=sentence_ids,
per_host_batch_size=per_host_batch_size,
seq_length=seq_length,
reuse_length=reuse_length,
bi_data=bi_data,
num_cores_per_host=num_cores_per_host)
write_instances_to_tfrecord(instances=instances, save_path=save_path)
if task_id is None or task_id == 0:
corpus_info = {
"vocab_size": 32000,
"per_host_batch_size": per_host_batch_size,
"num_cores_per_host": num_cores_per_host,
"seq_length": seq_length,
"reuse_length": reuse_length,
"do_lower_case": do_lower_case,
"bi_data": bi_data,
"use_eod_token": use_eod_token,
}
corpus_fname = os.path.basename(filename) + ".json"
corpus_destination = os.path.join(save_dir, corpus_fname)
logging.info("Saving corpus info to %s", corpus_destination)
with tf.io.gfile.GFile(corpus_destination, "w") as fp:
json.dump(corpus_info, fp)
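# Example invocation (the script name and paths are placeholders; the flag
# names are assumed to mirror the FLAGS attributes read in main() below):
#   python create_xlnet_pretraining_data.py \
#     --sp_model_file=/path/to/spiece.model \
#     --input_file=/path/to/corpus-*.txt \
#     --save_dir=/tmp/xlnet_pretrain \
#     --seq_length=512 --reuse_length=256 \
#     --per_host_batch_size=32 --num_cores_per_host=8 \
#     --bi_data --use_eod_token --do_lower_case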
def main(_):
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
create_tfrecords(
tokenizer=tokenizer,
input_file_or_files=FLAGS.input_file,
use_eod_token=FLAGS.use_eod_token,
do_lower_case=FLAGS.do_lower_case,
per_host_batch_size=FLAGS.per_host_batch_size,
seq_length=FLAGS.seq_length,
reuse_length=FLAGS.reuse_length,
bi_data=FLAGS.bi_data,
num_cores_per_host=FLAGS.num_cores_per_host,
save_dir=FLAGS.save_dir,
prefix=FLAGS.prefix,
suffix=FLAGS.suffix,
num_tasks=FLAGS.num_tasks,
task_id=FLAGS.task_id,
num_passes=FLAGS.num_passes)
if __name__ == "__main__":
np.random.seed(0)
logging.set_verbosity(logging.INFO)
app.run(main)