Commit e8659893 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 319839860
parent 7b91ccb1
......@@ -42,7 +42,6 @@ class ClsHeadConfig(base_config.Config):
@dataclasses.dataclass
class BertPretrainerConfig(base_config.Config):
"""BERT encoder configuration."""
num_masked_tokens: int = 76
encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
......@@ -55,16 +54,15 @@ def instantiate_classification_heads_from_cfgs(
] if cls_head_configs else []
def instantiate_bertpretrainer_from_cfg(
def instantiate_pretrainer_from_cfg(
config: BertPretrainerConfig,
encoder_network: Optional[tf.keras.Model] = None
) -> bert_pretrainer.BertPretrainerV2:
) -> bert_pretrainer.BertPretrainerV2:
"""Instantiates a BertPretrainer from the config."""
encoder_cfg = config.encoder
if encoder_network is None:
encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
return bert_pretrainer.BertPretrainerV2(
config.num_masked_tokens,
mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
......
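For context, a minimal sketch of how the renamed factory is now used, mirroring the test below (import paths as used elsewhere in this repository):

from official.nlp.configs import bert
from official.nlp.configs import encoders

# `num_masked_tokens` is no longer a field of BertPretrainerConfig; the masked-LM
# head is always built and the masked positions are supplied at call time.
config = bert.BertPretrainerConfig(
    encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
    cls_heads=[
        bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
    ])
pretrainer = bert.instantiate_pretrainer_from_cfg(config)  # a BertPretrainerV2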
......@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
_ = bert.instantiate_bertpretrainer_from_cfg(config)
_ = bert.instantiate_pretrainer_from_cfg(config)
# Invokes with classification heads.
config = bert.BertPretrainerConfig(
......@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_bertpretrainer_from_cfg(config)
_ = bert.instantiate_pretrainer_from_cfg(config)
with self.assertRaises(ValueError):
config = bert.BertPretrainerConfig(
......@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_bertpretrainer_from_cfg(config)
_ = bert.instantiate_pretrainer_from_cfg(config)
def test_checkpoint_items(self):
config = bert.BertPretrainerConfig(
......@@ -56,7 +56,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
encoder = bert.instantiate_bertpretrainer_from_cfg(config)
encoder = bert.instantiate_pretrainer_from_cfg(config)
self.assertSameElements(
encoder.checkpoint_items.keys(),
["encoder", "masked_lm", "next_sentence.pooler_dense"])
......
......@@ -21,6 +21,7 @@ from __future__ import print_function
import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling import networks
......@@ -43,23 +44,25 @@ class BertClassifier(tf.keras.Model):
num_classes: Number of classes to predict from the classification network.
initializer: The initializer (if any) to use in the classification networks.
Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or
'predictions'.
dropout_rate: The dropout probability of the cls head.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside
the encoder.
"""
def __init__(self,
network,
num_classes,
initializer='glorot_uniform',
output='logits',
dropout_rate=0.1,
use_encoder_pooler=True,
**kwargs):
self._self_setattr_tracking = False
self._network = network
self._config = {
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'output': output,
'use_encoder_pooler': use_encoder_pooler,
}
# We want to use the inputs of the passed network as the inputs to this
......@@ -67,6 +70,7 @@ class BertClassifier(tf.keras.Model):
# when we construct the Model object at the end of init.
inputs = network.inputs
if use_encoder_pooler:
# Because we have a copy of inputs to create this Model object, we can
# invoke the Network object with its own input tensors to start the Model.
_, cls_output = network(inputs)
......@@ -76,13 +80,26 @@ class BertClassifier(tf.keras.Model):
input_width=cls_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
output=output,
name='classification')
output='logits',
name='sentence_prediction')
predictions = self.classifier(cls_output)
else:
sequence_output, _ = network(inputs)
self.classifier = layers.ClassificationHead(
inner_dim=sequence_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
dropout_rate=dropout_rate,
name='sentence_prediction')
predictions = self.classifier(sequence_output)
super(BertClassifier, self).__init__(
inputs=inputs, outputs=predictions, **kwargs)
@property
def checkpoint_items(self):
return dict(encoder=self._network)
def get_config(self):
return self._config
......
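A hedged construction sketch for the updated BertClassifier interface; `encoder` here is an assumed transformer network (e.g. one built from networks.TransformerEncoder) that returns (sequence_output, cls_output):

from official.nlp.modeling.models import bert_classifier

classifier = bert_classifier.BertClassifier(
    network=encoder,           # assumed encoder model, not defined here
    num_classes=3,
    initializer='glorot_uniform',
    dropout_rate=0.1,          # used by the ClassificationHead when pooling in the head
    use_encoder_pooler=False)  # pool inside the head instead of reusing the encoder pooler
restorable = classifier.checkpoint_items['encoder']  # the wrapped encoder network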
......@@ -42,8 +42,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
# Create a BERT trainer with the created network.
bert_trainer_model = bert_classifier.BertClassifier(
test_network,
num_classes=num_classes)
test_network, num_classes=num_classes)
# Create a set of 2-dimensional inputs (the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
......@@ -89,7 +88,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
# Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
bert_trainer_model = bert_classifier.BertClassifier(
test_network, num_classes=4, initializer='zeros', output='predictions')
test_network, num_classes=4, initializer='zeros')
# Create another BERT trainer via serialization and deserialization.
config = bert_trainer_model.get_config()
......
......@@ -147,11 +147,9 @@ class BertPretrainerV2(tf.keras.Model):
(Experimental).
Adds the masked language model head and optional classification heads upon the
transformer encoder. When num_masked_tokens == 0, there won't be MaskedLM
head.
transformer encoder.
Arguments:
num_masked_tokens: Number of tokens to predict from the masked LM.
encoder_network: A transformer network. This network should output a
sequence output and a classification output.
mlm_activation: The activation (if any) to use in the masked LM network. If
......@@ -169,7 +167,6 @@ class BertPretrainerV2(tf.keras.Model):
def __init__(
self,
num_masked_tokens: int,
encoder_network: tf.keras.Model,
mlm_activation=None,
mlm_initializer='glorot_uniform',
......@@ -179,7 +176,6 @@ class BertPretrainerV2(tf.keras.Model):
self._self_setattr_tracking = False
self._config = {
'encoder_network': encoder_network,
'num_masked_tokens': num_masked_tokens,
'mlm_initializer': mlm_initializer,
'classification_heads': classification_heads,
'name': name,
......@@ -195,16 +191,13 @@ class BertPretrainerV2(tf.keras.Model):
raise ValueError('Classification heads should have unique names.')
outputs = dict()
if num_masked_tokens > 0:
self.masked_lm = layers.MaskedLM(
embedding_table=self.encoder_network.get_embedding_table(),
activation=mlm_activation,
initializer=mlm_initializer,
name='cls/predictions')
masked_lm_positions = tf.keras.layers.Input(
shape=(num_masked_tokens,),
name='masked_lm_positions',
dtype=tf.int32)
shape=(None,), name='masked_lm_positions', dtype=tf.int32)
inputs.append(masked_lm_positions)
outputs['lm_output'] = self.masked_lm(
sequence_output, masked_positions=masked_lm_positions)
......
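A small sketch of the updated BertPretrainerV2 construction; `encoder` is an assumed transformer network with the usual word_ids/mask/type_ids inputs:

from official.nlp.modeling.models import bert_pretrainer

pretrainer = bert_pretrainer.BertPretrainerV2(encoder_network=encoder)
# The appended `masked_lm_positions` input now has shape (None,), so the number
# of masked positions per example is no longer baked into the graph and may vary
# from batch to batch.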
......@@ -118,10 +118,9 @@ class BertPretrainerTest(keras_parameterized.TestCase):
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
# Create a BERT trainer with the created network.
num_token_predictions = 2
bert_trainer_model = bert_pretrainer.BertPretrainerV2(
encoder_network=test_network, num_masked_tokens=num_token_predictions)
encoder_network=test_network)
num_token_predictions = 20
# Create a set of 2-dimensional inputs (the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
......@@ -145,7 +144,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
# Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
bert_trainer_model = bert_pretrainer.BertPretrainerV2(
encoder_network=test_network, num_masked_tokens=2)
encoder_network=test_network)
# Create another BERT trainer via serialization and deserialization.
config = bert_trainer_model.get_config()
......
......@@ -41,7 +41,7 @@ class MaskedLMTask(base_task.Task):
"""Mock task object for testing."""
def build_model(self):
return bert.instantiate_bertpretrainer_from_cfg(self.task_config.model)
return bert.instantiate_pretrainer_from_cfg(self.task_config.model)
def build_losses(self,
labels,
......
......@@ -30,7 +30,6 @@ class MLMTaskTest(tf.test.TestCase):
init_checkpoint=self.get_temp_dir(),
model=bert.BertPretrainerConfig(
encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
num_masked_tokens=20,
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
......
......@@ -23,6 +23,7 @@ import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task
from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.bert import squad_evaluate_v1_1
from official.nlp.bert import squad_evaluate_v2_0
......@@ -35,6 +36,13 @@ from official.nlp.modeling import models
from official.nlp.tasks import utils
@dataclasses.dataclass
class ModelConfig(base_config.Config):
"""A base span labeler configuration."""
encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
@dataclasses.dataclass
class QuestionAnsweringConfig(cfg.TaskConfig):
"""The model config."""
......@@ -44,8 +52,7 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
n_best_size: int = 20
max_answer_length: int = 30
null_score_diff_threshold: float = 0.0
model: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
model: ModelConfig = ModelConfig()
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
......@@ -81,12 +88,12 @@ class QuestionAnsweringTask(base_task.Task):
encoder_network = utils.get_encoder_from_hub(self._hub_module)
else:
encoder_network = encoders.instantiate_encoder_from_cfg(
self.task_config.model)
self.task_config.model.encoder)
# Currently, we only support bert-style question answering fine-tuning.
return models.BertSpanLabeler(
network=encoder_network,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.task_config.model.initializer_range))
stddev=self.task_config.model.encoder.initializer_range))
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
start_positions = labels['start_positions']
......
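With the encoder configuration now nested under `model.encoder`, a question-answering task config is assembled as in the tests below; a sketch in which the data configs are placeholders:

from official.nlp.configs import encoders
from official.nlp.tasks import question_answering

config = question_answering.QuestionAnsweringConfig(
    model=question_answering.ModelConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1)),
    train_data=train_data_config,            # assumed cfg.DataConfig instances
    validation_data=validation_data_config)
task = question_answering.QuestionAnsweringTask(config)
model = task.build_model()  # a models.BertSpanLabeler built around the encoder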
......@@ -93,19 +93,18 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
# Saves a checkpoint.
pretrain_cfg = bert.BertPretrainerConfig(
encoder=self._encoder_config,
num_masked_tokens=20,
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=3, name="next_sentence")
])
pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
pretrain_model = bert.instantiate_pretrainer_from_cfg(pretrain_cfg)
ckpt = tf.train.Checkpoint(
model=pretrain_model, **pretrain_model.checkpoint_items)
saved_path = ckpt.save(self.get_temp_dir())
config = question_answering.QuestionAnsweringConfig(
init_checkpoint=saved_path,
model=self._encoder_config,
model=question_answering.ModelConfig(encoder=self._encoder_config),
train_data=self._train_data_config,
validation_data=self._get_validation_data_config(
version_2_with_negative))
......@@ -113,7 +112,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
def test_task_with_fit(self):
config = question_answering.QuestionAnsweringConfig(
model=self._encoder_config,
model=question_answering.ModelConfig(encoder=self._encoder_config),
train_data=self._train_data_config,
validation_data=self._get_validation_data_config())
task = question_answering.QuestionAnsweringTask(config)
......@@ -156,7 +155,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
hub_module_url = self._export_bert_tfhub()
config = question_answering.QuestionAnsweringConfig(
hub_module_url=hub_module_url,
model=self._encoder_config,
model=question_answering.ModelConfig(encoder=self._encoder_config),
train_data=self._train_data_config,
validation_data=self._get_validation_data_config())
self._run_task(config)
......
......@@ -23,12 +23,23 @@ import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task
from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
from official.nlp.modeling import models
from official.nlp.tasks import utils
@dataclasses.dataclass
class ModelConfig(base_config.Config):
"""A classifier/regressor configuration."""
num_classes: int = 0
use_encoder_pooler: bool = False
encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
@dataclasses.dataclass
class SentencePredictionConfig(cfg.TaskConfig):
"""The model config."""
......@@ -38,15 +49,8 @@ class SentencePredictionConfig(cfg.TaskConfig):
init_cls_pooler: bool = False
hub_module_url: str = ''
metric_type: str = 'accuracy'
model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
num_masked_tokens=0, # No masked language modeling head.
cls_heads=[
bert.ClsHeadConfig(
inner_dim=768,
num_classes=3,
dropout_rate=0.1,
name='sentence_prediction')
])
# Defines the concrete model config at instantiation time.
model: ModelConfig = ModelConfig()
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
......@@ -68,17 +72,22 @@ class SentencePredictionTask(base_task.Task):
def build_model(self):
if self._hub_module:
encoder_from_hub = utils.get_encoder_from_hub(self._hub_module)
return bert.instantiate_bertpretrainer_from_cfg(
self.task_config.model, encoder_network=encoder_from_hub)
encoder_network = utils.get_encoder_from_hub(self._hub_module)
else:
return bert.instantiate_bertpretrainer_from_cfg(self.task_config.model)
encoder_network = encoders.instantiate_encoder_from_cfg(
self.task_config.model.encoder)
# Currently, we only support bert-style sentence prediction fine-tuning.
return models.BertClassifier(
network=encoder_network,
num_classes=self.task_config.model.num_classes,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.task_config.model.encoder.initializer_range),
use_encoder_pooler=self.task_config.model.use_encoder_pooler)
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
loss = tf.keras.losses.sparse_categorical_crossentropy(
labels,
tf.cast(model_outputs['sentence_prediction'], tf.float32),
from_logits=True)
labels, tf.cast(model_outputs, tf.float32), from_logits=True)
if aux_losses:
loss += tf.add_n(aux_losses)
......@@ -112,10 +121,10 @@ class SentencePredictionTask(base_task.Task):
def process_metrics(self, metrics, labels, model_outputs):
for metric in metrics:
metric.update_state(labels, model_outputs['sentence_prediction'])
metric.update_state(labels, model_outputs)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
compiled_metrics.update_state(labels, model_outputs['sentence_prediction'])
compiled_metrics.update_state(labels, model_outputs)
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
if self.metric_type == 'accuracy':
......@@ -129,15 +138,13 @@ class SentencePredictionTask(base_task.Task):
if self.metric_type == 'matthews_corrcoef':
logs.update({
'sentence_prediction':
tf.expand_dims(
tf.math.argmax(outputs['sentence_prediction'], axis=1),
axis=0),
tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=0),
'labels':
labels,
})
if self.metric_type == 'pearson_spearman_corr':
logs.update({
'sentence_prediction': outputs['sentence_prediction'],
'sentence_prediction': outputs,
'labels': labels,
})
return logs
......
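The sentence-prediction task now builds a plain BertClassifier from a small ModelConfig, and the model output consumed by the losses and metrics is a logits tensor rather than a dict keyed by 'sentence_prediction'. A sketch mirroring the test helper below:

from official.nlp.configs import encoders
from official.nlp.tasks import sentence_prediction

config = sentence_prediction.SentencePredictionConfig(
    model=sentence_prediction.ModelConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
        num_classes=3),
    train_data=train_data_config)  # assumed cfg.DataConfig
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()  # a models.BertClassifier emitting logits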
......@@ -37,16 +37,10 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
input_path="dummy", seq_length=128, global_batch_size=1))
def get_model_config(self, num_classes):
return bert.BertPretrainerConfig(
return sentence_prediction.ModelConfig(
encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1),
num_masked_tokens=0,
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10,
num_classes=num_classes,
name="sentence_prediction")
])
num_classes=num_classes)
def _run_task(self, config):
task = sentence_prediction.SentencePredictionTask(config)
......@@ -81,12 +75,11 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
pretrain_cfg = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1),
num_masked_tokens=20,
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=3, name="next_sentence")
])
pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
pretrain_model = bert.instantiate_pretrainer_from_cfg(pretrain_cfg)
ckpt = tf.train.Checkpoint(
model=pretrain_model, **pretrain_model.checkpoint_items)
ckpt.save(config.init_checkpoint)
......
......@@ -25,6 +25,7 @@ import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task
from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
......@@ -32,14 +33,22 @@ from official.nlp.modeling import models
from official.nlp.tasks import utils
@dataclasses.dataclass
class ModelConfig(base_config.Config):
"""A base span labeler configuration."""
encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
head_dropout: float = 0.1
head_initializer_range: float = 0.02
@dataclasses.dataclass
class TaggingConfig(cfg.TaskConfig):
"""The model config."""
# At most one of `init_checkpoint` and `hub_module_url` can be specified.
init_checkpoint: str = ''
hub_module_url: str = ''
model: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
model: ModelConfig = ModelConfig()
# The real class names; their order should match the real label ids.
# Note that a word may be tokenized into multiple word_pieces tokens, and
......@@ -93,14 +102,14 @@ class TaggingTask(base_task.Task):
encoder_network = utils.get_encoder_from_hub(self._hub_module)
else:
encoder_network = encoders.instantiate_encoder_from_cfg(
self.task_config.model)
self.task_config.model.encoder)
return models.BertTokenClassifier(
network=encoder_network,
num_classes=len(self.task_config.class_names),
initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.task_config.model.initializer_range),
dropout_rate=self.task_config.model.dropout_rate,
stddev=self.task_config.model.head_initializer_range),
dropout_rate=self.task_config.model.head_dropout,
output='logits')
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
......
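Tagging follows the same pattern: the encoder config plus the head hyperparameters now live in a ModelConfig. A sketch mirroring the tests below:

from official.nlp.configs import encoders
from official.nlp.tasks import tagging

config = tagging.TaggingConfig(
    model=tagging.ModelConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
        head_dropout=0.1,
        head_initializer_range=0.02),
    train_data=train_data_config,  # assumed cfg.DataConfig
    class_names=["O", "B-PER", "I-PER"])
task = tagging.TaggingTask(config)  # builds a models.BertTokenClassifier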
......@@ -56,7 +56,7 @@ class TaggingTest(tf.test.TestCase):
config = tagging.TaggingConfig(
init_checkpoint=saved_path,
model=self._encoder_config,
model=tagging.ModelConfig(encoder=self._encoder_config),
train_data=self._train_data_config,
class_names=["O", "B-PER", "I-PER"])
task = tagging.TaggingTask(config)
......@@ -72,7 +72,7 @@ class TaggingTest(tf.test.TestCase):
def test_task_with_fit(self):
config = tagging.TaggingConfig(
model=self._encoder_config,
model=tagging.ModelConfig(encoder=self._encoder_config),
train_data=self._train_data_config,
class_names=["O", "B-PER", "I-PER"])
......@@ -115,14 +115,13 @@ class TaggingTest(tf.test.TestCase):
hub_module_url = self._export_bert_tfhub()
config = tagging.TaggingConfig(
hub_module_url=hub_module_url,
model=self._encoder_config,
class_names=["O", "B-PER", "I-PER"],
train_data=self._train_data_config)
self._run_task(config)
def test_seqeval_metrics(self):
config = tagging.TaggingConfig(
model=self._encoder_config,
model=tagging.ModelConfig(encoder=self._encoder_config),
train_data=self._train_data_config,
class_names=["O", "B-PER", "I-PER"])
task = tagging.TaggingTask(config)
......