Commit 21b73d22 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 317010998
parent a3263c0f
@@ -13,7 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A multi-head BERT encoder network for pretraining."""
"""Multi-head BERT encoder network with classification heads.
Includes configurations and instantiation methods.
"""
from typing import List, Optional, Text
import dataclasses
@@ -24,7 +27,6 @@ from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer
@@ -47,43 +49,34 @@ class BertPretrainerConfig(base_config.Config):
cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
def instantiate_from_cfg(
def instantiate_classification_heads_from_cfgs(
cls_head_configs: List[ClsHeadConfig]) -> List[layers.ClassificationHead]:
return [
layers.ClassificationHead(**head_cfg.as_dict())
for head_cfg in cls_head_configs
] if cls_head_configs else []
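# A minimal usage sketch (values are illustrative, mirroring the unit tests):
# each ClsHeadConfig expands into keyword arguments for
# layers.ClassificationHead, so
#   instantiate_classification_heads_from_cfgs(
#       [ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")])
# returns a one-element list containing a "next_sentence" head.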
def instantiate_bertpretrainer_from_cfg(
config: BertPretrainerConfig,
encoder_network: Optional[tf.keras.Model] = None):
encoder_network: Optional[tf.keras.Model] = None
) -> bert_pretrainer.BertPretrainerV2:
"""Instantiates a BertPretrainer from the config."""
encoder_cfg = config.encoder
if encoder_network is None:
encoder_network = networks.TransformerEncoder(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range))
if config.cls_heads:
classification_heads = [
layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads
]
else:
classification_heads = []
encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
return bert_pretrainer.BertPretrainerV2(
config.num_masked_tokens,
mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
encoder_network=encoder_network,
classification_heads=classification_heads)
classification_heads=instantiate_classification_heads_from_cfgs(
config.cls_heads))
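# Usage sketch (mirrors the unit tests; the values are illustrative):
#   config = BertPretrainerConfig(
#       encoder=encoders.TransformerEncoderConfig(
#           vocab_size=30522, num_layers=1),
#       num_masked_tokens=20,
#       cls_heads=[ClsHeadConfig(
#           inner_dim=10, num_classes=2, name="next_sentence")])
#   pretrainer = instantiate_bertpretrainer_from_cfg(config)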
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
"""Data config for BERT pretraining task."""
"""Data config for BERT pretraining task (tasks/masked_lm)."""
input_path: str = ""
global_batch_size: int = 512
is_training: bool = True
@@ -95,15 +88,15 @@ class BertPretrainDataConfig(cfg.DataConfig):
@dataclasses.dataclass
class BertPretrainEvalDataConfig(BertPretrainDataConfig):
"""Data config for the eval set in BERT pretraining task."""
"""Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
input_path: str = ""
global_batch_size: int = 512
is_training: bool = False
@dataclasses.dataclass
class BertSentencePredictionDataConfig(cfg.DataConfig):
"""Data of sentence prediction dataset."""
class SentencePredictionDataConfig(cfg.DataConfig):
"""Data config for sentence prediction task (tasks/sentence_prediction)."""
input_path: str = ""
global_batch_size: int = 32
is_training: bool = True
@@ -111,10 +104,29 @@ class BertSentencePredictionDataConfig(cfg.DataConfig):
@dataclasses.dataclass
class BertSentencePredictionDevDataConfig(cfg.DataConfig):
"""Dev data of MNLI sentence prediction dataset."""
class SentencePredictionDevDataConfig(cfg.DataConfig):
"""Dev Data config for sentence prediction (tasks/sentence_prediction)."""
input_path: str = ""
global_batch_size: int = 32
is_training: bool = False
seq_length: int = 128
drop_remainder: bool = False
@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
"""Data config for question answering task (tasks/question_answering)."""
input_path: str = ""
global_batch_size: int = 48
is_training: bool = True
seq_length: int = 384
@dataclasses.dataclass
class QADevDataConfig(cfg.DataConfig):
"""Dev Data config for queston answering (tasks/question_answering)."""
input_path: str = ""
global_batch_size: int = 48
is_training: bool = False
seq_length: int = 384
drop_remainder: bool = False
@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
_ = bert.instantiate_from_cfg(config)
_ = bert.instantiate_bertpretrainer_from_cfg(config)
# Invokes with classification heads.
config = bert.BertPretrainerConfig(
@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_from_cfg(config)
_ = bert.instantiate_bertpretrainer_from_cfg(config)
with self.assertRaises(ValueError):
config = bert.BertPretrainerConfig(
@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_from_cfg(config)
_ = bert.instantiate_bertpretrainer_from_cfg(config)
def test_checkpoint_items(self):
config = bert.BertPretrainerConfig(
@@ -56,7 +56,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
encoder = bert.instantiate_from_cfg(config)
encoder = bert.instantiate_bertpretrainer_from_cfg(config)
self.assertSameElements(encoder.checkpoint_items.keys(),
["encoder", "next_sentence.pooler_dense"])
@@ -13,11 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configurations for Encoders."""
"""Transformer Encoders.
Includes configurations and instantiation methods.
"""
import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.modeling import networks
@dataclasses.dataclass
@@ -34,3 +40,22 @@ class TransformerEncoderConfig(base_config.Config):
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
def instantiate_encoder_from_cfg(
config: TransformerEncoderConfig) -> networks.TransformerEncoder:
"""Instantiate a Transformer encoder network from TransformerEncoderConfig."""
encoder_network = networks.TransformerEncoder(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
num_layers=config.num_layers,
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
activation=tf_utils.get_activation(config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
max_sequence_length=config.max_position_embeddings,
type_vocab_size=config.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range))
return encoder_network
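# Usage sketch (a small test-sized encoder, as in the unit tests):
#   encoder = instantiate_encoder_from_cfg(
#       TransformerEncoderConfig(vocab_size=30522, num_layers=1))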
@@ -51,11 +51,13 @@ class BertSpanLabeler(tf.keras.Model):
output='logits',
**kwargs):
self._self_setattr_tracking = False
self._network = network
self._config = {
'network': network,
'initializer': initializer,
'output': output,
}
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use
# when we construct the Model object at the end of init.
@@ -89,6 +91,10 @@ class BertSpanLabeler(tf.keras.Model):
super(BertSpanLabeler, self).__init__(
inputs=inputs, outputs=logits, **kwargs)
@property
def checkpoint_items(self):
return dict(encoder=self._network)
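# Exposing the encoder through `checkpoint_items` lets a task restore just
# the encoder weights from a pretraining checkpoint, e.g.:
#   ckpt = tf.train.Checkpoint(**model.checkpoint_items)
#   ckpt.restore(ckpt_dir_or_file)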
def get_config(self):
return self._config
@@ -40,7 +40,7 @@ class MaskedLMTask(base_task.Task):
"""Mock task object for testing."""
def build_model(self):
return bert.instantiate_from_cfg(self.task_config.network)
return bert.instantiate_bertpretrainer_from_cfg(self.task_config.network)
def build_losses(self,
labels,
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Question answering task."""
import dataclasses
import logging
import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.bert import input_pipeline
from official.nlp.configs import encoders
from official.nlp.modeling import models
@dataclasses.dataclass
class QuestionAnsweringConfig(cfg.TaskConfig):
"""The model config."""
# At most one of `init_checkpoint` and `hub_module_url` can be specified.
init_checkpoint: str = ''
hub_module_url: str = ''
network: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
@base_task.register_task_cls(QuestionAnsweringConfig)
class QuestionAnsweringTask(base_task.Task):
"""Task object for question answering.
TODO(lehou): Add post-processing.
"""
def __init__(self, params: cfg.TaskConfig):
super(QuestionAnsweringTask, self).__init__(params)
if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if params.hub_module_url:
self._hub_module = hub.load(params.hub_module_url)
else:
self._hub_module = None
def build_model(self):
if self._hub_module:
# TODO(lehou): maybe add the hub_module building logic to a util function.
input_word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_word_ids')
input_mask = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_type_ids')
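# Note: the BERT hub module returns (pooled_output, sequence_output); the
# wrapper model below reorders them to [sequence_output, pooled_output] so
# the hub encoder presents the same output order as a TransformerEncoder.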
bert_model = hub.KerasLayer(self._hub_module, trainable=True)
pooled_output, sequence_output = bert_model(
[input_word_ids, input_mask, input_type_ids])
encoder_network = tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[sequence_output, pooled_output])
else:
encoder_network = encoders.instantiate_encoder_from_cfg(
self.task_config.network)
return models.BertSpanLabeler(
network=encoder_network,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.task_config.network.initializer_range))
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
start_positions = labels['start_positions']
end_positions = labels['end_positions']
start_logits, end_logits = model_outputs
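# Logits are cast to float32 so the crossentropy is computed in full
# precision even if the model itself runs under mixed precision.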
start_loss = tf.keras.losses.sparse_categorical_crossentropy(
start_positions,
tf.cast(start_logits, dtype=tf.float32),
from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
end_positions,
tf.cast(end_logits, dtype=tf.float32),
from_logits=True)
loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
return loss
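# Shape sketch (assumed SQuAD-style batches): start/end logits are
# [batch_size, seq_length], the label positions are [batch_size] token
# indices, and the final loss is the mean of the two per-position
# crossentropies, e.g. start_loss=1.2 and end_loss=0.8 average to 1.0.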
def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for sentence_prediction task."""
if params.input_path == 'dummy':
def dummy_data(_):
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
x = dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
y = dict(
start_positions=tf.constant(0, dtype=tf.int32),
end_positions=tf.constant(1, dtype=tf.int32))
return (x, y)
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
batch_size = input_context.get_per_replica_batch_size(
params.global_batch_size) if input_context else params.global_batch_size
# TODO(chendouble): add and use nlp.data.question_answering_dataloader.
dataset = input_pipeline.create_squad_dataset(
params.input_path,
params.seq_length,
batch_size,
is_training=params.is_training,
input_pipeline_context=input_context)
return dataset
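# For quick tests, input_path='dummy' short-circuits file reading, e.g. (as
# in the unit tests):
#   dataset = task.build_inputs(bert.QADataConfig(
#       input_path='dummy', seq_length=128, global_batch_size=1))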
def build_metrics(self, training=None):
del training
# TODO(lehou): a list of metrics doesn't work the same as in compile/fit.
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(
name='start_position_accuracy'),
tf.keras.metrics.SparseCategoricalAccuracy(
name='end_position_accuracy'),
]
return metrics
def process_metrics(self, metrics, labels, model_outputs):
metrics = {metric.name: metric for metric in metrics}
start_logits, end_logits = model_outputs
metrics['start_position_accuracy'].update_state(
labels['start_positions'], start_logits)
metrics['end_position_accuracy'].update_state(
labels['end_positions'], end_logits)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
start_logits, end_logits = model_outputs
compiled_metrics.update_state(
y_true=labels, # labels has keys 'start_positions' and 'end_positions'.
y_pred={'start_positions': start_logits, 'end_positions': end_logits})
def initialize(self, model):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file:
return
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
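# The pretraining checkpoint contains more than the encoder (e.g. the MLM
# head); expect_partial() silences warnings about those unused values, while
# assert_existing_objects_matched() still verifies that the encoder objects
# requested here were all matched and restored.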
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.question_answering."""
import functools
import os
import tensorflow as tf
from official.nlp.bert import configs
from official.nlp.bert import export_tfhub
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.tasks import question_answering
class QuestionAnsweringTaskTest(tf.test.TestCase):
def setUp(self):
super(QuestionAnsweringTaskTest, self).setUp()
self._encoder_config = encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1)
self._train_data_config = bert.QADataConfig(
input_path="dummy", seq_length=128, global_batch_size=1)
def _run_task(self, config):
task = question_answering.QuestionAnsweringTask(config)
model = task.build_model()
metrics = task.build_metrics()
strategy = tf.distribute.get_strategy()
dataset = strategy.experimental_distribute_datasets_from_function(
functools.partial(task.build_inputs, config.train_data))
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
def test_task(self):
# Saves a checkpoint.
pretrain_cfg = bert.BertPretrainerConfig(
encoder=self._encoder_config,
num_masked_tokens=20,
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=3, name="next_sentence")
])
pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
ckpt = tf.train.Checkpoint(
model=pretrain_model, **pretrain_model.checkpoint_items)
saved_path = ckpt.save(self.get_temp_dir())
config = question_answering.QuestionAnsweringConfig(
init_checkpoint=saved_path,
network=self._encoder_config,
train_data=self._train_data_config)
task = question_answering.QuestionAnsweringTask(config)
model = task.build_model()
metrics = task.build_metrics()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
task.initialize(model)
def test_task_with_fit(self):
config = question_answering.QuestionAnsweringConfig(
network=self._encoder_config,
train_data=self._train_data_config)
task = question_answering.QuestionAnsweringTask(config)
model = task.build_model()
model = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
train_step=task.train_step,
metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])
dataset = task.build_inputs(config.train_data)
logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
self.assertIn("loss", logs.history)
self.assertIn("start_positions_accuracy", logs.history)
self.assertIn("end_positions_accuracy", logs.history)
def _export_bert_tfhub(self):
bert_config = configs.BertConfig(
vocab_size=30522,
hidden_size=16,
intermediate_size=32,
max_position_embeddings=128,
num_attention_heads=2,
num_hidden_layers=1)
_, encoder = export_tfhub.create_bert_model(bert_config)
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
with tf.io.gfile.GFile(vocab_file, "w") as f:
f.write("dummy content")
hub_destination = os.path.join(self.get_temp_dir(), "hub")
export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
hub_destination, vocab_file)
return hub_destination
def test_task_with_hub(self):
hub_module_url = self._export_bert_tfhub()
config = question_answering.QuestionAnsweringConfig(
hub_module_url=hub_module_url,
network=self._encoder_config,
train_data=self._train_data_config)
self._run_task(config)
if __name__ == "__main__":
tf.test.main()
@@ -34,7 +34,7 @@ class SentencePredictionConfig(cfg.TaskConfig):
init_checkpoint: str = ''
hub_module_url: str = ''
network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
num_masked_tokens=0,
num_masked_tokens=0, # No masked language modeling head.
cls_heads=[
bert.ClsHeadConfig(
inner_dim=768,
@@ -74,10 +74,10 @@ class SentencePredictionTask(base_task.Task):
encoder_from_hub = tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[sequence_output, pooled_output])
return bert.instantiate_from_cfg(
return bert.instantiate_bertpretrainer_from_cfg(
self.task_config.network, encoder_network=encoder_from_hub)
else:
return bert.instantiate_from_cfg(self.task_config.network)
return bert.instantiate_bertpretrainer_from_cfg(self.task_config.network)
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
@@ -27,6 +27,19 @@ from official.nlp.tasks import sentence_prediction
class SentencePredictionTaskTest(tf.test.TestCase):
def setUp(self):
super(SentencePredictionTaskTest, self).setUp()
self._network_config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1),
num_masked_tokens=0,
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=3, name="sentence_prediction")
])
self._train_data_config = bert.SentencePredictionDataConfig(
input_path="dummy", seq_length=128, global_batch_size=1)
def _run_task(self, config):
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
@@ -44,16 +57,8 @@ class SentencePredictionTaskTest(tf.test.TestCase):
def test_task(self):
config = sentence_prediction.SentencePredictionConfig(
init_checkpoint=self.get_temp_dir(),
network=bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1),
num_masked_tokens=0,
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=3, name="sentence_prediction")
]),
train_data=bert.BertSentencePredictionDataConfig(
input_path="dummy", seq_length=128, global_batch_size=1))
network=self._network_config,
train_data=self._train_data_config)
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
metrics = task.build_metrics()
@@ -73,12 +78,27 @@ class SentencePredictionTaskTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=3, name="next_sentence")
])
pretrain_model = bert.instantiate_from_cfg(pretrain_cfg)
pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
ckpt = tf.train.Checkpoint(
model=pretrain_model, **pretrain_model.checkpoint_items)
ckpt.save(config.init_checkpoint)
task.initialize(model)
def test_task_with_fit(self):
config = sentence_prediction.SentencePredictionConfig(
network=self._network_config,
train_data=self._train_data_config)
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
model = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
train_step=task.train_step,
metrics=task.build_metrics())
dataset = task.build_inputs(config.train_data)
logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
self.assertIn("loss", logs.history)
def _export_bert_tfhub(self):
bert_config = configs.BertConfig(
vocab_size=30522,
@@ -106,15 +126,8 @@ class SentencePredictionTaskTest(tf.test.TestCase):
hub_module_url = self._export_bert_tfhub()
config = sentence_prediction.SentencePredictionConfig(
hub_module_url=hub_module_url,
network=bert.BertPretrainerConfig(
encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
num_masked_tokens=0,
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=3, name="sentence_prediction")
]),
train_data=bert.BertSentencePredictionDataConfig(
input_path="dummy", seq_length=128, global_batch_size=10))
network=self._network_config,
train_data=self._train_data_config)
self._run_task(config)