Commit e8659893 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 319839860
parent 7b91ccb1
@@ -42,7 +42,6 @@ class ClsHeadConfig(base_config.Config):
 @dataclasses.dataclass
 class BertPretrainerConfig(base_config.Config):
   """BERT encoder configuration."""
-  num_masked_tokens: int = 76
   encoder: encoders.TransformerEncoderConfig = (
       encoders.TransformerEncoderConfig())
   cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
@@ -55,16 +54,15 @@ def instantiate_classification_heads_from_cfgs(
   ] if cls_head_configs else []


-def instantiate_bertpretrainer_from_cfg(
+def instantiate_pretrainer_from_cfg(
     config: BertPretrainerConfig,
     encoder_network: Optional[tf.keras.Model] = None
 ) -> bert_pretrainer.BertPretrainerV2:
   """Instantiates a BertPretrainer from the config."""
   encoder_cfg = config.encoder
   if encoder_network is None:
     encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
   return bert_pretrainer.BertPretrainerV2(
-      config.num_masked_tokens,
       mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=encoder_cfg.initializer_range),
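With num_masked_tokens removed, the pretrainer is configured from the encoder alone and the masked-LM head is always built. A minimal sketch of the renamed factory, using the same values as the tests below:

from official.nlp.configs import bert
from official.nlp.configs import encoders

config = bert.BertPretrainerConfig(
    encoder=encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
    cls_heads=[
        bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
    ])
# No token-count argument: masked positions are now a dynamic model input.
pretrainer = bert.instantiate_pretrainer_from_cfg(config)
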
@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
   def test_network_invocation(self):
     config = bert.BertPretrainerConfig(
         encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)

     # Invokes with classification heads.
     config = bert.BertPretrainerConfig(
@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
             bert.ClsHeadConfig(
                 inner_dim=10, num_classes=2, name="next_sentence")
         ])
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)

     with self.assertRaises(ValueError):
       config = bert.BertPretrainerConfig(
@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
               bert.ClsHeadConfig(
                   inner_dim=10, num_classes=2, name="next_sentence")
           ])
-      _ = bert.instantiate_bertpretrainer_from_cfg(config)
+      _ = bert.instantiate_pretrainer_from_cfg(config)

   def test_checkpoint_items(self):
     config = bert.BertPretrainerConfig(
@@ -56,7 +56,7 @@ class BertModelsTest(tf.test.TestCase):
             bert.ClsHeadConfig(
                 inner_dim=10, num_classes=2, name="next_sentence")
         ])
-    encoder = bert.instantiate_bertpretrainer_from_cfg(config)
+    encoder = bert.instantiate_pretrainer_from_cfg(config)
     self.assertSameElements(
         encoder.checkpoint_items.keys(),
         ["encoder", "masked_lm", "next_sentence.pooler_dense"])
@@ -21,6 +21,7 @@ from __future__ import print_function
 import tensorflow as tf

+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
@@ -43,23 +44,25 @@ class BertClassifier(tf.keras.Model):
     num_classes: Number of classes to predict from the classification network.
     initializer: The initializer (if any) to use in the classification networks.
       Defaults to a Glorot uniform initializer.
-    output: The output style for this network. Can be either 'logits' or
-      'predictions'.
+    dropout_rate: The dropout probability of the cls head.
+    use_encoder_pooler: Whether to use the pooler layer pre-defined inside
+      the encoder.
   """

   def __init__(self,
                network,
                num_classes,
                initializer='glorot_uniform',
-               output='logits',
                dropout_rate=0.1,
+               use_encoder_pooler=True,
                **kwargs):
     self._self_setattr_tracking = False
+    self._network = network
     self._config = {
         'network': network,
         'num_classes': num_classes,
         'initializer': initializer,
-        'output': output,
+        'use_encoder_pooler': use_encoder_pooler,
     }

     # We want to use the inputs of the passed network as the inputs to this
@@ -67,22 +70,36 @@ class BertClassifier(tf.keras.Model):
     # when we construct the Model object at the end of init.
     inputs = network.inputs

-    # Because we have a copy of inputs to create this Model object, we can
-    # invoke the Network object with its own input tensors to start the Model.
-    _, cls_output = network(inputs)
-    cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
-
-    self.classifier = networks.Classification(
-        input_width=cls_output.shape[-1],
-        num_classes=num_classes,
-        initializer=initializer,
-        output=output,
-        name='classification')
-    predictions = self.classifier(cls_output)
+    if use_encoder_pooler:
+      # Because we have a copy of inputs to create this Model object, we can
+      # invoke the Network object with its own input tensors to start the Model.
+      _, cls_output = network(inputs)
+      cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
+
+      self.classifier = networks.Classification(
+          input_width=cls_output.shape[-1],
+          num_classes=num_classes,
+          initializer=initializer,
+          output='logits',
+          name='sentence_prediction')
+      predictions = self.classifier(cls_output)
+    else:
+      sequence_output, _ = network(inputs)
+      self.classifier = layers.ClassificationHead(
+          inner_dim=sequence_output.shape[-1],
+          num_classes=num_classes,
+          initializer=initializer,
+          dropout_rate=dropout_rate,
+          name='sentence_prediction')
+      predictions = self.classifier(sequence_output)

     super(BertClassifier, self).__init__(
         inputs=inputs, outputs=predictions, **kwargs)

+  @property
+  def checkpoint_items(self):
+    return dict(encoder=self._network)
+
   def get_config(self):
     return self._config
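A minimal sketch of the new use_encoder_pooler switch and the exposed encoder; the encoder sizes here are illustrative, not taken from the commit:

from official.nlp.modeling import models
from official.nlp.modeling import networks

test_network = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, sequence_length=16)
# use_encoder_pooler=False routes the sequence output through the new
# layers.ClassificationHead instead of the encoder's built-in pooler output.
classifier = models.BertClassifier(
    network=test_network, num_classes=3, use_encoder_pooler=False)
# The wrapped encoder is now reachable for checkpointing.
encoder = classifier.checkpoint_items['encoder']
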
@@ -42,8 +42,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
-        test_network,
-        num_classes=num_classes)
+        test_network, num_classes=num_classes)

     # Create a set of 2-dimensional inputs (the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -89,7 +88,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
     bert_trainer_model = bert_classifier.BertClassifier(
-        test_network, num_classes=4, initializer='zeros', output='predictions')
+        test_network, num_classes=4, initializer='zeros')

     # Create another BERT trainer via serialization and deserialization.
     config = bert_trainer_model.get_config()
@@ -147,11 +147,9 @@ class BertPretrainerV2(tf.keras.Model):
   (Experimental).

   Adds the masked language model head and optional classification heads upon the
-  transformer encoder. When num_masked_tokens == 0, there won't be MaskedLM
-  head.
+  transformer encoder.

   Arguments:
-    num_masked_tokens: Number of tokens to predict from the masked LM.
     encoder_network: A transformer network. This network should output a
       sequence output and a classification output.
     mlm_activation: The activation (if any) to use in the masked LM network. If
@@ -169,7 +167,6 @@ class BertPretrainerV2(tf.keras.Model):
   def __init__(
       self,
-      num_masked_tokens: int,
       encoder_network: tf.keras.Model,
       mlm_activation=None,
       mlm_initializer='glorot_uniform',
@@ -179,7 +176,6 @@ class BertPretrainerV2(tf.keras.Model):
     self._self_setattr_tracking = False
     self._config = {
         'encoder_network': encoder_network,
-        'num_masked_tokens': num_masked_tokens,
         'mlm_initializer': mlm_initializer,
         'classification_heads': classification_heads,
         'name': name,
@@ -195,19 +191,16 @@ class BertPretrainerV2(tf.keras.Model):
       raise ValueError('Classification heads should have unique names.')

     outputs = dict()
-    if num_masked_tokens > 0:
-      self.masked_lm = layers.MaskedLM(
-          embedding_table=self.encoder_network.get_embedding_table(),
-          activation=mlm_activation,
-          initializer=mlm_initializer,
-          name='cls/predictions')
-      masked_lm_positions = tf.keras.layers.Input(
-          shape=(num_masked_tokens,),
-          name='masked_lm_positions',
-          dtype=tf.int32)
-      inputs.append(masked_lm_positions)
-      outputs['lm_output'] = self.masked_lm(
-          sequence_output, masked_positions=masked_lm_positions)
+    self.masked_lm = layers.MaskedLM(
+        embedding_table=self.encoder_network.get_embedding_table(),
+        activation=mlm_activation,
+        initializer=mlm_initializer,
+        name='cls/predictions')
+    masked_lm_positions = tf.keras.layers.Input(
+        shape=(None,), name='masked_lm_positions', dtype=tf.int32)
+    inputs.append(masked_lm_positions)
+    outputs['lm_output'] = self.masked_lm(
+        sequence_output, masked_positions=masked_lm_positions)
     for cls_head in self.classification_heads:
       outputs[cls_head.name] = cls_head(sequence_output)
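Since masked_lm_positions is now a dynamic-length input, the number of predictions is chosen per batch at call time rather than fixed at construction. A minimal sketch with illustrative sizes, assuming the TransformerEncoder arguments used in the tests:

import numpy as np
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer

encoder = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, sequence_length=16)
pretrainer = bert_pretrainer.BertPretrainerV2(encoder_network=encoder)

batch_size, num_predictions = 2, 20
word_ids = np.random.randint(0, 100, (batch_size, 16))
mask = np.ones((batch_size, 16), dtype=np.int32)
type_ids = np.zeros((batch_size, 16), dtype=np.int32)
masked_positions = np.random.randint(0, 16, (batch_size, num_predictions))
# Inputs are the encoder's inputs plus the appended masked_lm_positions tensor.
outputs = pretrainer([word_ids, mask, type_ids, masked_positions])
lm_output = outputs['lm_output']
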
@@ -118,10 +118,9 @@ class BertPretrainerTest(keras_parameterized.TestCase):
         vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)

     # Create a BERT trainer with the created network.
-    num_token_predictions = 2
     bert_trainer_model = bert_pretrainer.BertPretrainerV2(
-        encoder_network=test_network, num_masked_tokens=num_token_predictions)
+        encoder_network=test_network)
+    num_token_predictions = 20

     # Create a set of 2-dimensional inputs (the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -145,7 +144,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
     bert_trainer_model = bert_pretrainer.BertPretrainerV2(
-        encoder_network=test_network, num_masked_tokens=2)
+        encoder_network=test_network)

     # Create another BERT trainer via serialization and deserialization.
     config = bert_trainer_model.get_config()
@@ -41,7 +41,7 @@ class MaskedLMTask(base_task.Task):
   """Mock task object for testing."""

   def build_model(self):
-    return bert.instantiate_bertpretrainer_from_cfg(self.task_config.model)
+    return bert.instantiate_pretrainer_from_cfg(self.task_config.model)

   def build_losses(self,
                    labels,
@@ -30,7 +30,6 @@ class MLMTaskTest(tf.test.TestCase):
         init_checkpoint=self.get_temp_dir(),
         model=bert.BertPretrainerConfig(
             encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
-            num_masked_tokens=20,
             cls_heads=[
                 bert.ClsHeadConfig(
                     inner_dim=10, num_classes=2, name="next_sentence")
@@ -23,6 +23,7 @@ import tensorflow as tf
 import tensorflow_hub as hub

 from official.core import base_task
+from official.modeling.hyperparams import base_config
 from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.bert import squad_evaluate_v1_1
 from official.nlp.bert import squad_evaluate_v2_0
@@ -35,6 +36,13 @@ from official.nlp.modeling import models
 from official.nlp.tasks import utils


+@dataclasses.dataclass
+class ModelConfig(base_config.Config):
+  """A base span labeler configuration."""
+  encoder: encoders.TransformerEncoderConfig = (
+      encoders.TransformerEncoderConfig())
+
+
 @dataclasses.dataclass
 class QuestionAnsweringConfig(cfg.TaskConfig):
   """The model config."""
@@ -44,8 +52,7 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
   n_best_size: int = 20
   max_answer_length: int = 30
   null_score_diff_threshold: float = 0.0
-  model: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
+  model: ModelConfig = ModelConfig()
   train_data: cfg.DataConfig = cfg.DataConfig()
   validation_data: cfg.DataConfig = cfg.DataConfig()
@@ -81,12 +88,12 @@ class QuestionAnsweringTask(base_task.Task):
       encoder_network = utils.get_encoder_from_hub(self._hub_module)
     else:
       encoder_network = encoders.instantiate_encoder_from_cfg(
-          self.task_config.model)
+          self.task_config.model.encoder)
+    # Currently, we only support bert-style question answering finetuning.
     return models.BertSpanLabeler(
         network=encoder_network,
         initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=self.task_config.model.initializer_range))
+            stddev=self.task_config.model.encoder.initializer_range))

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
     start_positions = labels['start_positions']
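With the encoder settings now nested under model.encoder, the task config is built as in the updated tests below; a minimal sketch:

from official.nlp.configs import encoders
from official.nlp.tasks import question_answering

config = question_answering.QuestionAnsweringConfig(
    model=question_answering.ModelConfig(
        encoder=encoders.TransformerEncoderConfig(
            vocab_size=30522, num_layers=1)))
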
@@ -93,19 +93,18 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
     # Saves a checkpoint.
     pretrain_cfg = bert.BertPretrainerConfig(
         encoder=self._encoder_config,
-        num_masked_tokens=20,
         cls_heads=[
             bert.ClsHeadConfig(
                 inner_dim=10, num_classes=3, name="next_sentence")
         ])
-    pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
+    pretrain_model = bert.instantiate_pretrainer_from_cfg(pretrain_cfg)
     ckpt = tf.train.Checkpoint(
         model=pretrain_model, **pretrain_model.checkpoint_items)
     saved_path = ckpt.save(self.get_temp_dir())

     config = question_answering.QuestionAnsweringConfig(
         init_checkpoint=saved_path,
-        model=self._encoder_config,
+        model=question_answering.ModelConfig(encoder=self._encoder_config),
         train_data=self._train_data_config,
         validation_data=self._get_validation_data_config(
             version_2_with_negative))
@@ -113,7 +112,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
   def test_task_with_fit(self):
     config = question_answering.QuestionAnsweringConfig(
-        model=self._encoder_config,
+        model=question_answering.ModelConfig(encoder=self._encoder_config),
         train_data=self._train_data_config,
         validation_data=self._get_validation_data_config())
     task = question_answering.QuestionAnsweringTask(config)
@@ -156,7 +155,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
     hub_module_url = self._export_bert_tfhub()
     config = question_answering.QuestionAnsweringConfig(
         hub_module_url=hub_module_url,
-        model=self._encoder_config,
+        model=question_answering.ModelConfig(encoder=self._encoder_config),
         train_data=self._train_data_config,
         validation_data=self._get_validation_data_config())
     self._run_task(config)
@@ -23,12 +23,23 @@ import tensorflow as tf
 import tensorflow_hub as hub

 from official.core import base_task
+from official.modeling.hyperparams import base_config
 from official.modeling.hyperparams import config_definitions as cfg
-from official.nlp.configs import bert
+from official.nlp.configs import encoders
 from official.nlp.data import data_loader_factory
+from official.nlp.modeling import models
 from official.nlp.tasks import utils


+@dataclasses.dataclass
+class ModelConfig(base_config.Config):
+  """A classifier/regressor configuration."""
+  num_classes: int = 0
+  use_encoder_pooler: bool = False
+  encoder: encoders.TransformerEncoderConfig = (
+      encoders.TransformerEncoderConfig())
+
+
 @dataclasses.dataclass
 class SentencePredictionConfig(cfg.TaskConfig):
   """The model config."""
@@ -38,15 +49,8 @@ class SentencePredictionConfig(cfg.TaskConfig):
   init_cls_pooler: bool = False
   hub_module_url: str = ''
   metric_type: str = 'accuracy'
-  model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
-      num_masked_tokens=0,  # No masked language modeling head.
-      cls_heads=[
-          bert.ClsHeadConfig(
-              inner_dim=768,
-              num_classes=3,
-              dropout_rate=0.1,
-              name='sentence_prediction')
-      ])
+  # Defines the concrete model config at instantiation time.
+  model: ModelConfig = ModelConfig()
   train_data: cfg.DataConfig = cfg.DataConfig()
   validation_data: cfg.DataConfig = cfg.DataConfig()
@@ -68,17 +72,22 @@ class SentencePredictionTask(base_task.Task):
   def build_model(self):
     if self._hub_module:
-      encoder_from_hub = utils.get_encoder_from_hub(self._hub_module)
-      return bert.instantiate_bertpretrainer_from_cfg(
-          self.task_config.model, encoder_network=encoder_from_hub)
+      encoder_network = utils.get_encoder_from_hub(self._hub_module)
     else:
-      return bert.instantiate_bertpretrainer_from_cfg(self.task_config.model)
+      encoder_network = encoders.instantiate_encoder_from_cfg(
+          self.task_config.model.encoder)
+    # Currently, we only support bert-style sentence prediction finetuning.
+    return models.BertClassifier(
+        network=encoder_network,
+        num_classes=self.task_config.model.num_classes,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=self.task_config.model.encoder.initializer_range),
+        use_encoder_pooler=self.task_config.model.use_encoder_pooler)

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
     loss = tf.keras.losses.sparse_categorical_crossentropy(
-        labels,
-        tf.cast(model_outputs['sentence_prediction'], tf.float32),
-        from_logits=True)
+        labels, tf.cast(model_outputs, tf.float32), from_logits=True)

     if aux_losses:
       loss += tf.add_n(aux_losses)
@@ -112,10 +121,10 @@ class SentencePredictionTask(base_task.Task):
   def process_metrics(self, metrics, labels, model_outputs):
     for metric in metrics:
-      metric.update_state(labels, model_outputs['sentence_prediction'])
+      metric.update_state(labels, model_outputs)

   def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
-    compiled_metrics.update_state(labels, model_outputs['sentence_prediction'])
+    compiled_metrics.update_state(labels, model_outputs)

   def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
     if self.metric_type == 'accuracy':
@@ -129,15 +138,13 @@ class SentencePredictionTask(base_task.Task):
     if self.metric_type == 'matthews_corrcoef':
       logs.update({
           'sentence_prediction':
-              tf.expand_dims(
-                  tf.math.argmax(outputs['sentence_prediction'], axis=1),
-                  axis=0),
+              tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=0),
           'labels':
               labels,
       })
     if self.metric_type == 'pearson_spearman_corr':
       logs.update({
-          'sentence_prediction': outputs['sentence_prediction'],
+          'sentence_prediction': outputs,
           'labels': labels,
       })
     return logs
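The sentence prediction task now builds a plain models.BertClassifier from a small ModelConfig instead of a full BertPretrainerConfig, so model outputs are raw logits rather than a dict. A minimal sketch matching the updated test helper below:

from official.nlp.configs import encoders
from official.nlp.tasks import sentence_prediction

config = sentence_prediction.SentencePredictionConfig(
    model=sentence_prediction.ModelConfig(
        encoder=encoders.TransformerEncoderConfig(
            vocab_size=30522, num_layers=1),
        num_classes=3))
task = sentence_prediction.SentencePredictionTask(config)
classifier = task.build_model()  # a models.BertClassifier producing logits
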
@@ -37,16 +37,10 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
             input_path="dummy", seq_length=128, global_batch_size=1))

   def get_model_config(self, num_classes):
-    return bert.BertPretrainerConfig(
+    return sentence_prediction.ModelConfig(
         encoder=encoders.TransformerEncoderConfig(
             vocab_size=30522, num_layers=1),
-        num_masked_tokens=0,
-        cls_heads=[
-            bert.ClsHeadConfig(
-                inner_dim=10,
-                num_classes=num_classes,
-                name="sentence_prediction")
-        ])
+        num_classes=num_classes)

   def _run_task(self, config):
     task = sentence_prediction.SentencePredictionTask(config)
@@ -81,12 +75,11 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     pretrain_cfg = bert.BertPretrainerConfig(
         encoder=encoders.TransformerEncoderConfig(
             vocab_size=30522, num_layers=1),
-        num_masked_tokens=20,
         cls_heads=[
             bert.ClsHeadConfig(
                 inner_dim=10, num_classes=3, name="next_sentence")
         ])
-    pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
+    pretrain_model = bert.instantiate_pretrainer_from_cfg(pretrain_cfg)
     ckpt = tf.train.Checkpoint(
         model=pretrain_model, **pretrain_model.checkpoint_items)
     ckpt.save(config.init_checkpoint)
@@ -25,6 +25,7 @@ import tensorflow as tf
 import tensorflow_hub as hub

 from official.core import base_task
+from official.modeling.hyperparams import base_config
 from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.configs import encoders
 from official.nlp.data import data_loader_factory
@@ -32,14 +33,22 @@ from official.nlp.modeling import models
 from official.nlp.tasks import utils


+@dataclasses.dataclass
+class ModelConfig(base_config.Config):
+  """A base span labeler configuration."""
+  encoder: encoders.TransformerEncoderConfig = (
+      encoders.TransformerEncoderConfig())
+  head_dropout: float = 0.1
+  head_initializer_range: float = 0.02
+
+
 @dataclasses.dataclass
 class TaggingConfig(cfg.TaskConfig):
   """The model config."""
   # At most one of `init_checkpoint` and `hub_module_url` can be specified.
   init_checkpoint: str = ''
   hub_module_url: str = ''
-  model: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
+  model: ModelConfig = ModelConfig()

   # The real class names, the order of which should match real label id.
   # Note that a word may be tokenized into multiple word_pieces tokens, and
@@ -93,14 +102,14 @@ class TaggingTask(base_task.Task):
       encoder_network = utils.get_encoder_from_hub(self._hub_module)
     else:
       encoder_network = encoders.instantiate_encoder_from_cfg(
-          self.task_config.model)
+          self.task_config.model.encoder)

     return models.BertTokenClassifier(
         network=encoder_network,
         num_classes=len(self.task_config.class_names),
         initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=self.task_config.model.initializer_range),
-        dropout_rate=self.task_config.model.dropout_rate,
+            stddev=self.task_config.model.head_initializer_range),
+        dropout_rate=self.task_config.model.head_dropout,
         output='logits')

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
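The tagging task likewise nests the encoder and gains head-specific fields for the token-classification head; a minimal sketch using the class names from the tests below, with the default head values spelled out:

from official.nlp.configs import encoders
from official.nlp.tasks import tagging

config = tagging.TaggingConfig(
    model=tagging.ModelConfig(
        encoder=encoders.TransformerEncoderConfig(
            vocab_size=30522, num_layers=1),
        head_dropout=0.1,
        head_initializer_range=0.02),
    class_names=["O", "B-PER", "I-PER"])
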
@@ -56,7 +56,7 @@ class TaggingTest(tf.test.TestCase):
     config = tagging.TaggingConfig(
         init_checkpoint=saved_path,
-        model=self._encoder_config,
+        model=tagging.ModelConfig(encoder=self._encoder_config),
         train_data=self._train_data_config,
         class_names=["O", "B-PER", "I-PER"])
     task = tagging.TaggingTask(config)
@@ -72,7 +72,7 @@ class TaggingTest(tf.test.TestCase):
   def test_task_with_fit(self):
     config = tagging.TaggingConfig(
-        model=self._encoder_config,
+        model=tagging.ModelConfig(encoder=self._encoder_config),
         train_data=self._train_data_config,
         class_names=["O", "B-PER", "I-PER"])
@@ -115,14 +115,13 @@ class TaggingTest(tf.test.TestCase):
     hub_module_url = self._export_bert_tfhub()
     config = tagging.TaggingConfig(
         hub_module_url=hub_module_url,
-        model=self._encoder_config,
         class_names=["O", "B-PER", "I-PER"],
         train_data=self._train_data_config)
     self._run_task(config)

   def test_seqeval_metrics(self):
     config = tagging.TaggingConfig(
-        model=self._encoder_config,
+        model=tagging.ModelConfig(encoder=self._encoder_config),
         train_data=self._train_data_config,
         class_names=["O", "B-PER", "I-PER"])
     task = tagging.TaggingTask(config)