Commit 29d45e88 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 325073831
parent 36748648
@@ -20,13 +20,9 @@ Includes configurations and instantiation methods.
 from typing import List, Optional, Text
 import dataclasses
-import tensorflow as tf
-from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
 from official.nlp.configs import encoders
-from official.nlp.modeling import layers
-from official.nlp.modeling.models import bert_pretrainer
 @dataclasses.dataclass
@@ -40,32 +36,9 @@ class ClsHeadConfig(base_config.Config):
 @dataclasses.dataclass
-class BertPretrainerConfig(base_config.Config):
-  """BERT encoder configuration."""
-  encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
+class PretrainerConfig(base_config.Config):
+  """Pretrainer configuration."""
+  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
   cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
+  mlm_activation: str = "gelu"
+  mlm_initializer_range: float = 0.02
-def instantiate_classification_heads_from_cfgs(
-    cls_head_configs: List[ClsHeadConfig]) -> List[layers.ClassificationHead]:
-  return [
-      layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
-  ] if cls_head_configs else []
-def instantiate_pretrainer_from_cfg(
-    config: BertPretrainerConfig,
-    encoder_network: Optional[tf.keras.Model] = None
-) -> bert_pretrainer.BertPretrainerV2:
-  """Instantiates a BertPretrainer from the config."""
-  encoder_cfg = config.encoder
-  if encoder_network is None:
-    encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
-  return bert_pretrainer.BertPretrainerV2(
-      mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
-      mlm_initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=encoder_cfg.initializer_range),
-      encoder_network=encoder_network,
-      classification_heads=instantiate_classification_heads_from_cfgs(
-          config.cls_heads))
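For illustration only, a minimal sketch of how a task can assemble the pretrainer from the new PretrainerConfig now that instantiate_pretrainer_from_cfg is removed; it mirrors the updated MaskedLMTask.build_model further down in this diff and uses only names that appear in the changed files:

import tensorflow as tf

from official.modeling import tf_utils
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling import models

# Build a BertPretrainerV2 directly from the new PretrainerConfig (sketch).
config = bert.PretrainerConfig(
    encoder=encoders.EncoderConfig(
        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
encoder_network = encoders.build_encoder(config.encoder)
pretrainer = models.BertPretrainerV2(
    mlm_activation=tf_utils.get_activation(config.mlm_activation),
    mlm_initializer=tf.keras.initializers.TruncatedNormal(
        stddev=config.mlm_initializer_range),
    encoder_network=encoder_network,
    classification_heads=[
        layers.ClassificationHead(**head.as_dict()) for head in config.cls_heads
    ])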
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BERT configurations and models instantiation."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders
class BertModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
_ = bert.instantiate_pretrainer_from_cfg(config)
# Invokes with classification heads.
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_pretrainer_from_cfg(config)
with self.assertRaises(ValueError):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence"),
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_pretrainer_from_cfg(config)
def test_checkpoint_items(self):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
encoder = bert.instantiate_pretrainer_from_cfg(config)
self.assertSameElements(
encoder.checkpoint_items.keys(),
["encoder", "masked_lm", "next_sentence.pooler_dense"])
if __name__ == "__main__":
tf.test.main()
@@ -14,21 +14,17 @@
 # limitations under the License.
 # ==============================================================================
 """ELECTRA model configurations and instantiation methods."""
-from typing import List, Optional
+from typing import List
 import dataclasses
-import tensorflow as tf
-from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
 from official.nlp.configs import bert
 from official.nlp.configs import encoders
-from official.nlp.modeling import layers
-from official.nlp.modeling.models import electra_pretrainer
 @dataclasses.dataclass
-class ELECTRAPretrainerConfig(base_config.Config):
+class ElectraPretrainerConfig(base_config.Config):
   """ELECTRA pretrainer configuration."""
   num_masked_tokens: int = 76
   sequence_length: int = 512
@@ -36,56 +32,6 @@ class ELECTRAPretrainerConfig(base_config.Config):
   discriminator_loss_weight: float = 50.0
   tie_embeddings: bool = True
   disallow_correct: bool = False
-  generator_encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
-  discriminator_encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
+  generator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
+  discriminator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
   cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)
-def instantiate_classification_heads_from_cfgs(
-    cls_head_configs: List[bert.ClsHeadConfig]
-) -> List[layers.ClassificationHead]:
-  if cls_head_configs:
-    return [
-        layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
-    ]
-  else:
-    return []
-def instantiate_pretrainer_from_cfg(
-    config: ELECTRAPretrainerConfig,
-    generator_network: Optional[tf.keras.Model] = None,
-    discriminator_network: Optional[tf.keras.Model] = None,
-) -> electra_pretrainer.ElectraPretrainer:
-  """Instantiates ElectraPretrainer from the config."""
-  generator_encoder_cfg = config.generator_encoder
-  discriminator_encoder_cfg = config.discriminator_encoder
-  # Copy discriminator's embeddings to generator for easier model serialization.
-  if discriminator_network is None:
-    discriminator_network = encoders.instantiate_encoder_from_cfg(
-        discriminator_encoder_cfg)
-  if generator_network is None:
-    if config.tie_embeddings:
-      embedding_layer = discriminator_network.get_embedding_layer()
-      generator_network = encoders.instantiate_encoder_from_cfg(
-          generator_encoder_cfg, embedding_layer=embedding_layer)
-    else:
-      generator_network = encoders.instantiate_encoder_from_cfg(
-          generator_encoder_cfg)
-  return electra_pretrainer.ElectraPretrainer(
-      generator_network=generator_network,
-      discriminator_network=discriminator_network,
-      vocab_size=config.generator_encoder.vocab_size,
-      num_classes=config.num_classes,
-      sequence_length=config.sequence_length,
-      num_token_predictions=config.num_masked_tokens,
-      mlm_activation=tf_utils.get_activation(
-          generator_encoder_cfg.hidden_activation),
-      mlm_initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=generator_encoder_cfg.initializer_range),
-      classification_heads=instantiate_classification_heads_from_cfgs(
-          config.cls_heads),
-      disallow_correct=config.disallow_correct)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA configurations and models instantiation."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
class ELECTRAModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
)
_ = electra.instantiate_pretrainer_from_cfg(config)
# Invokes with classification heads.
config = electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = electra.instantiate_pretrainer_from_cfg(config)
if __name__ == "__main__":
tf.test.main()
@@ -15,20 +15,23 @@
 # ==============================================================================
 """Transformer Encoders.
-Includes configurations and instantiation methods.
+Includes configurations and factory methods.
 """
 from typing import Optional
+from absl import logging
 import dataclasses
+import gin
 import tensorflow as tf
+from official.modeling import hyperparams
 from official.modeling import tf_utils
-from official.modeling.hyperparams import base_config
 from official.nlp.modeling import layers
 from official.nlp.modeling import networks
 @dataclasses.dataclass
-class TransformerEncoderConfig(base_config.Config):
+class BertEncoderConfig(hyperparams.Config):
   """BERT encoder configuration."""
   vocab_size: int = 30522
   hidden_size: int = 768
@@ -44,55 +47,86 @@ class TransformerEncoderConfig(base_config.Config):
   embedding_size: Optional[int] = None
-def instantiate_encoder_from_cfg(
-    config: TransformerEncoderConfig,
-    encoder_cls=networks.TransformerEncoder,
-    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
-  """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
+@dataclasses.dataclass
+class EncoderConfig(hyperparams.OneOfConfig):
+  """Encoder configuration."""
+  type: Optional[str] = "bert"
+  bert: BertEncoderConfig = BertEncoderConfig()
+ENCODER_CLS = {
+    "bert": networks.TransformerEncoder,
+}
+@gin.configurable
+def build_encoder(config: EncoderConfig,
+                  embedding_layer: Optional[layers.OnDeviceEmbedding] = None,
+                  encoder_cls=None,
+                  bypass_config: bool = False):
+  """Instantiate a Transformer encoder network from EncoderConfig.
+  Args:
+    config: the one-of encoder config, which provides encoder parameters of a
+      chosen encoder.
+    embedding_layer: an external embedding layer passed to the encoder.
+    encoder_cls: an external encoder cls not included in the supported encoders,
+      usually used by gin.configurable.
+    bypass_config: whether to ignore config instance to create the object with
+      `encoder_cls`.
+  Returns:
+    An encoder instance.
+  """
+  encoder_type = config.type
+  encoder_cfg = config.get()
+  encoder_cls = encoder_cls or ENCODER_CLS[encoder_type]
+  logging.info("Encoder class: %s to build...", encoder_cls.__name__)
+  if bypass_config:
+    return encoder_cls()
   if encoder_cls.__name__ == "EncoderScaffold":
     embedding_cfg = dict(
-        vocab_size=config.vocab_size,
-        type_vocab_size=config.type_vocab_size,
-        hidden_size=config.hidden_size,
-        max_seq_length=config.max_position_embeddings,
+        vocab_size=encoder_cfg.vocab_size,
+        type_vocab_size=encoder_cfg.type_vocab_size,
+        hidden_size=encoder_cfg.hidden_size,
+        max_seq_length=encoder_cfg.max_position_embeddings,
         initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=config.initializer_range),
-        dropout_rate=config.dropout_rate,
+            stddev=encoder_cfg.initializer_range),
+        dropout_rate=encoder_cfg.dropout_rate,
     )
     hidden_cfg = dict(
-        num_attention_heads=config.num_attention_heads,
-        intermediate_size=config.intermediate_size,
+        num_attention_heads=encoder_cfg.num_attention_heads,
+        intermediate_size=encoder_cfg.intermediate_size,
         intermediate_activation=tf_utils.get_activation(
-            config.hidden_activation),
-        dropout_rate=config.dropout_rate,
-        attention_dropout_rate=config.attention_dropout_rate,
+            encoder_cfg.hidden_activation),
+        dropout_rate=encoder_cfg.dropout_rate,
+        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=config.initializer_range),
+            stddev=encoder_cfg.initializer_range),
     )
     kwargs = dict(
         embedding_cfg=embedding_cfg,
         hidden_cfg=hidden_cfg,
-        num_hidden_instances=config.num_layers,
-        pooled_output_dim=config.hidden_size,
+        num_hidden_instances=encoder_cfg.num_layers,
+        pooled_output_dim=encoder_cfg.hidden_size,
         pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=config.initializer_range))
+            stddev=encoder_cfg.initializer_range))
     return encoder_cls(**kwargs)
-  if encoder_cls.__name__ != "TransformerEncoder":
-    raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
-  encoder_network = encoder_cls(
-      vocab_size=config.vocab_size,
-      hidden_size=config.hidden_size,
-      num_layers=config.num_layers,
-      num_attention_heads=config.num_attention_heads,
-      intermediate_size=config.intermediate_size,
-      activation=tf_utils.get_activation(config.hidden_activation),
-      dropout_rate=config.dropout_rate,
-      attention_dropout_rate=config.attention_dropout_rate,
-      max_sequence_length=config.max_position_embeddings,
-      type_vocab_size=config.type_vocab_size,
+  # Uses the default BERTEncoder configuration schema to create the encoder.
+  # If it does not match, please add a switch branch by the encoder type.
+  return encoder_cls(
+      vocab_size=encoder_cfg.vocab_size,
+      hidden_size=encoder_cfg.hidden_size,
+      num_layers=encoder_cfg.num_layers,
+      num_attention_heads=encoder_cfg.num_attention_heads,
+      intermediate_size=encoder_cfg.intermediate_size,
+      activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
+      dropout_rate=encoder_cfg.dropout_rate,
+      attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+      max_sequence_length=encoder_cfg.max_position_embeddings,
+      type_vocab_size=encoder_cfg.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=config.initializer_range),
-      embedding_width=config.embedding_size,
+          stddev=encoder_cfg.initializer_range),
+      embedding_width=encoder_cfg.embedding_size,
       embedding_layer=embedding_layer)
-  return encoder_network
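For illustration only, a small usage sketch of the new one-of encoder config and factory above (all names taken from this hunk); EncoderConfig.type selects the sub-config and EncoderConfig.get() returns it:

import tensorflow as tf

from official.nlp.configs import encoders

# Select the "bert" branch of the one-of config and build the Keras network.
encoder_config = encoders.EncoderConfig(
    bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=2))
encoder_network = encoders.build_encoder(encoder_config)

# EncoderConfig.get() returns the sub-config chosen by `type` ("bert" by
# default); this is how the tasks below now read fields such as
# initializer_range instead of reading them off the top-level config.
bert_cfg = encoder_config.get()
initializer = tf.keras.initializers.TruncatedNormal(
    stddev=bert_cfg.initializer_range)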
@@ -14,7 +14,7 @@
 # ==============================================================================
 """Models package definition."""
 from official.nlp.modeling.models.bert_classifier import BertClassifier
-from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
+from official.nlp.modeling.models.bert_pretrainer import *
 from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
 from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
 from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
@@ -19,16 +19,20 @@ import tensorflow as tf
 from official.core import base_task
 from official.core import task_factory
+from official.modeling import tf_utils
 from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.configs import bert
 from official.nlp.configs import electra
+from official.nlp.configs import encoders
 from official.nlp.data import pretrain_dataloader
+from official.nlp.modeling import layers
+from official.nlp.modeling import models
 @dataclasses.dataclass
-class ELECTRAPretrainConfig(cfg.TaskConfig):
+class ElectraPretrainConfig(cfg.TaskConfig):
   """The model config."""
-  model: electra.ELECTRAPretrainerConfig = electra.ELECTRAPretrainerConfig(
+  model: electra.ElectraPretrainerConfig = electra.ElectraPretrainerConfig(
       cls_heads=[
           bert.ClsHeadConfig(
              inner_dim=768,
@@ -40,13 +44,44 @@ class ELECTRAPretrainConfig(cfg.TaskConfig):
   validation_data: cfg.DataConfig = cfg.DataConfig()
-@task_factory.register_task_cls(ELECTRAPretrainConfig)
-class ELECTRAPretrainTask(base_task.Task):
+def _build_pretrainer(
+    config: electra.ElectraPretrainerConfig) -> models.ElectraPretrainer:
+  """Instantiates ElectraPretrainer from the config."""
+  generator_encoder_cfg = config.generator_encoder
+  discriminator_encoder_cfg = config.discriminator_encoder
+  # Copy discriminator's embeddings to generator for easier model serialization.
+  discriminator_network = encoders.build_encoder(discriminator_encoder_cfg)
+  if config.tie_embeddings:
+    embedding_layer = discriminator_network.get_embedding_layer()
+    generator_network = encoders.build_encoder(
+        generator_encoder_cfg, embedding_layer=embedding_layer)
+  else:
+    generator_network = encoders.build_encoder(generator_encoder_cfg)
+  generator_encoder_cfg = generator_encoder_cfg.get()
+  return models.ElectraPretrainer(
+      generator_network=generator_network,
+      discriminator_network=discriminator_network,
+      vocab_size=generator_encoder_cfg.vocab_size,
+      num_classes=config.num_classes,
+      sequence_length=config.sequence_length,
+      num_token_predictions=config.num_masked_tokens,
+      mlm_activation=tf_utils.get_activation(
+          generator_encoder_cfg.hidden_activation),
+      mlm_initializer=tf.keras.initializers.TruncatedNormal(
+          stddev=generator_encoder_cfg.initializer_range),
+      classification_heads=[
+          layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads
+      ],
+      disallow_correct=config.disallow_correct)
+@task_factory.register_task_cls(ElectraPretrainConfig)
+class ElectraPretrainTask(base_task.Task):
   """ELECTRA Pretrain Task (Masked LM + Replaced Token Detection)."""
   def build_model(self):
-    return electra.instantiate_pretrainer_from_cfg(
-        self.task_config.model)
+    return _build_pretrainer(self.task_config.model)
   def build_losses(self,
                    labels,
@@ -70,9 +105,7 @@ class ELECTRAPretrainTask(base_task.Task):
       sentence_outputs = tf.cast(
          model_outputs['sentence_outputs'], dtype=tf.float32)
      sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
-          sentence_labels,
-          sentence_outputs,
-          from_logits=True)
+          sentence_labels, sentence_outputs, from_logits=True)
      metrics['next_sentence_loss'].update_state(sentence_loss)
      total_loss = mlm_loss + sentence_loss
     else:
...
@@ -24,15 +24,17 @@ from official.nlp.data import pretrain_dataloader
 from official.nlp.tasks import electra_task
-class ELECTRAPretrainTaskTest(tf.test.TestCase):
+class ElectraPretrainTaskTest(tf.test.TestCase):
   def test_task(self):
-    config = electra_task.ELECTRAPretrainConfig(
-        model=electra.ELECTRAPretrainerConfig(
-            generator_encoder=encoders.TransformerEncoderConfig(
-                vocab_size=30522, num_layers=1),
-            discriminator_encoder=encoders.TransformerEncoderConfig(
-                vocab_size=30522, num_layers=1),
+    config = electra_task.ElectraPretrainConfig(
+        model=electra.ElectraPretrainerConfig(
+            generator_encoder=encoders.EncoderConfig(
+                bert=encoders.BertEncoderConfig(vocab_size=30522,
+                                                num_layers=1)),
+            discriminator_encoder=encoders.EncoderConfig(
+                bert=encoders.BertEncoderConfig(vocab_size=30522,
+                                                num_layers=1)),
             num_masked_tokens=20,
             sequence_length=128,
             cls_heads=[
@@ -44,7 +46,7 @@ class ELECTRAPretrainTaskTest(tf.test.TestCase):
             max_predictions_per_seq=20,
             seq_length=128,
             global_batch_size=1))
-    task = electra_task.ELECTRAPretrainTask(config)
+    task = electra_task.ElectraPretrainTask(config)
     model = task.build_model()
     metrics = task.build_metrics()
     dataset = task.build_inputs(config.train_data)
...
@@ -19,15 +19,19 @@ import tensorflow as tf
 from official.core import base_task
 from official.core import task_factory
+from official.modeling import tf_utils
 from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.configs import bert
+from official.nlp.configs import encoders
 from official.nlp.data import data_loader_factory
+from official.nlp.modeling import layers
+from official.nlp.modeling import models
 @dataclasses.dataclass
 class MaskedLMConfig(cfg.TaskConfig):
   """The model config."""
-  model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[
+  model: bert.PretrainerConfig = bert.PretrainerConfig(cls_heads=[
       bert.ClsHeadConfig(
          inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
   ])
@@ -37,11 +41,21 @@ class MaskedLMConfig(cfg.TaskConfig):
 @task_factory.register_task_cls(MaskedLMConfig)
 class MaskedLMTask(base_task.Task):
-  """Mock task object for testing."""
+  """Task object for Mask language modeling."""
   def build_model(self, params=None):
-    params = params or self.task_config.model
-    return bert.instantiate_pretrainer_from_cfg(params)
+    config = params or self.task_config.model
+    encoder_cfg = config.encoder
+    encoder_network = encoders.build_encoder(encoder_cfg)
+    cls_heads = [
+        layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads
+    ] if config.cls_heads else []
+    return models.BertPretrainerV2(
+        mlm_activation=tf_utils.get_activation(config.mlm_activation),
+        mlm_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.mlm_initializer_range),
+        encoder_network=encoder_network,
+        classification_heads=cls_heads)
   def build_losses(self,
                    labels,
@@ -63,9 +77,8 @@ class MaskedLMTask(base_task.Task):
       sentence_outputs = tf.cast(
          model_outputs['next_sentence'], dtype=tf.float32)
      sentence_loss = tf.reduce_mean(
-          tf.keras.losses.sparse_categorical_crossentropy(sentence_labels,
-                                                          sentence_outputs,
-                                                          from_logits=True))
+          tf.keras.losses.sparse_categorical_crossentropy(
+              sentence_labels, sentence_outputs, from_logits=True))
      metrics['next_sentence_loss'].update_state(sentence_loss)
      total_loss = mlm_loss + sentence_loss
     else:
...
@@ -28,8 +28,10 @@ class MLMTaskTest(tf.test.TestCase):
   def test_task(self):
     config = masked_lm.MaskedLMConfig(
         init_checkpoint=self.get_temp_dir(),
-        model=bert.BertPretrainerConfig(
-            encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
+        model=bert.PretrainerConfig(
+            encoders.EncoderConfig(
+                bert=encoders.BertEncoderConfig(vocab_size=30522,
+                                                num_layers=1)),
             cls_heads=[
                 bert.ClsHeadConfig(
                     inner_dim=10, num_classes=2, name="next_sentence")
...
@@ -42,8 +42,7 @@ from official.nlp.tasks import utils
 @dataclasses.dataclass
 class ModelConfig(base_config.Config):
   """A base span labeler configuration."""
-  encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
+  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
 @dataclasses.dataclass
@@ -94,13 +93,13 @@ class QuestionAnsweringTask(base_task.Task):
     if self._hub_module:
       encoder_network = utils.get_encoder_from_hub(self._hub_module)
     else:
-      encoder_network = encoders.instantiate_encoder_from_cfg(
-          self.task_config.model.encoder)
+      encoder_network = encoders.build_encoder(self.task_config.model.encoder)
+    encoder_cfg = self.task_config.model.encoder.get()
     # Currently, we only support bert-style question answering finetuning.
     return models.BertSpanLabeler(
         network=encoder_network,
         initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=self.task_config.model.encoder.initializer_range))
+            stddev=encoder_cfg.initializer_range))
   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
     start_positions = labels['start_positions']
...
@@ -25,6 +25,7 @@ from official.nlp.bert import export_tfhub
 from official.nlp.configs import bert
 from official.nlp.configs import encoders
 from official.nlp.data import question_answering_dataloader
+from official.nlp.tasks import masked_lm
 from official.nlp.tasks import question_answering
@@ -32,21 +33,37 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
   def setUp(self):
     super(QuestionAnsweringTaskTest, self).setUp()
-    self._encoder_config = encoders.TransformerEncoderConfig(
-        vocab_size=30522, num_layers=1)
+    self._encoder_config = encoders.EncoderConfig(
+        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))
     self._train_data_config = question_answering_dataloader.QADataConfig(
-        input_path="dummy",
-        seq_length=128,
-        global_batch_size=1)
-    val_data = {"version": "1.1",
-                "data": [{"paragraphs": [
-                    {"context": "Sky is blue.",
-                     "qas": [{"question": "What is blue?", "id": "1234",
-                              "answers": [{"text": "Sky", "answer_start": 0},
-                                          {"text": "Sky", "answer_start": 0},
-                                          {"text": "Sky", "answer_start": 0}]
-                              }]}]}]}
+        input_path="dummy", seq_length=128, global_batch_size=1)
+    val_data = {
+        "version":
+            "1.1",
+        "data": [{
+            "paragraphs": [{
+                "context":
+                    "Sky is blue.",
+                "qas": [{
+                    "question":
+                        "What is blue?",
+                    "id":
+                        "1234",
+                    "answers": [{
+                        "text": "Sky",
+                        "answer_start": 0
+                    }, {
+                        "text": "Sky",
+                        "answer_start": 0
+                    }, {
+                        "text": "Sky",
+                        "answer_start": 0
+                    }]
+                }]
+            }]
+        }]
+    }
     self._val_input_path = os.path.join(self.get_temp_dir(), "val_data.json")
     with tf.io.gfile.GFile(self._val_input_path, "w") as writer:
       writer.write(json.dumps(val_data, indent=4) + "\n")
@@ -87,19 +104,20 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
     metrics = task.reduce_aggregated_logs(logs)
     self.assertIn("final_f1", metrics)
-  @parameterized.parameters(itertools.product(
-      (False, True),
-      ("WordPiece", "SentencePiece"),
-  ))
+  @parameterized.parameters(
+      itertools.product(
+          (False, True),
+          ("WordPiece", "SentencePiece"),
+      ))
   def test_task(self, version_2_with_negative, tokenization):
     # Saves a checkpoint.
-    pretrain_cfg = bert.BertPretrainerConfig(
+    pretrain_cfg = bert.PretrainerConfig(
         encoder=self._encoder_config,
         cls_heads=[
            bert.ClsHeadConfig(
                inner_dim=10, num_classes=3, name="next_sentence")
        ])
-    pretrain_model = bert.instantiate_pretrainer_from_cfg(pretrain_cfg)
+    pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
     ckpt = tf.train.Checkpoint(
         model=pretrain_model, **pretrain_model.checkpoint_items)
     saved_path = ckpt.save(self.get_temp_dir())
...
@@ -44,8 +44,7 @@ class ModelConfig(base_config.Config):
   """A classifier/regressor configuration."""
   num_classes: int = 0
   use_encoder_pooler: bool = False
-  encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
+  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
 @dataclasses.dataclass
@@ -85,15 +84,14 @@ class SentencePredictionTask(base_task.Task):
     if self._hub_module:
       encoder_network = utils.get_encoder_from_hub(self._hub_module)
     else:
-      encoder_network = encoders.instantiate_encoder_from_cfg(
-          self.task_config.model.encoder)
+      encoder_network = encoders.build_encoder(self.task_config.model.encoder)
+    encoder_cfg = self.task_config.model.encoder.get()
     # Currently, we only support bert-style sentence prediction finetuning.
     return models.BertClassifier(
         network=encoder_network,
         num_classes=self.task_config.model.num_classes,
         initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=self.task_config.model.encoder.initializer_range),
+            stddev=encoder_cfg.initializer_range),
         use_encoder_pooler=self.task_config.model.use_encoder_pooler)
   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
...
@@ -26,6 +26,7 @@ from official.nlp.bert import export_tfhub
 from official.nlp.configs import bert
 from official.nlp.configs import encoders
 from official.nlp.data import sentence_prediction_dataloader
+from official.nlp.tasks import masked_lm
 from official.nlp.tasks import sentence_prediction
@@ -68,8 +69,8 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
   def get_model_config(self, num_classes):
     return sentence_prediction.ModelConfig(
-        encoder=encoders.TransformerEncoderConfig(
-            vocab_size=30522, num_layers=1),
+        encoder=encoders.EncoderConfig(
+            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
         num_classes=num_classes)
   def _run_task(self, config):
@@ -102,14 +103,14 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     task.validation_step(next(iterator), model, metrics=metrics)
     # Saves a checkpoint.
-    pretrain_cfg = bert.BertPretrainerConfig(
-        encoder=encoders.TransformerEncoderConfig(
-            vocab_size=30522, num_layers=1),
+    pretrain_cfg = bert.PretrainerConfig(
+        encoder=encoders.EncoderConfig(
+            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
         cls_heads=[
            bert.ClsHeadConfig(
                inner_dim=10, num_classes=3, name="next_sentence")
        ])
-    pretrain_model = bert.instantiate_pretrainer_from_cfg(pretrain_cfg)
+    pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
     ckpt = tf.train.Checkpoint(
         model=pretrain_model, **pretrain_model.checkpoint_items)
     ckpt.save(config.init_checkpoint)
@@ -136,8 +137,8 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
       if num_classes == 1:
         self.assertIsInstance(metrics[0], tf.keras.metrics.MeanSquaredError)
       else:
-        self.assertIsInstance(
-            metrics[0], tf.keras.metrics.SparseCategoricalAccuracy)
+        self.assertIsInstance(metrics[0],
+                              tf.keras.metrics.SparseCategoricalAccuracy)
     dataset = task.build_inputs(config.train_data)
     iterator = iter(dataset)
...
@@ -37,8 +37,7 @@ from official.nlp.tasks import utils
 @dataclasses.dataclass
 class ModelConfig(base_config.Config):
   """A base span labeler configuration."""
-  encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
+  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
   head_dropout: float = 0.1
   head_initializer_range: float = 0.02
@@ -102,8 +101,7 @@ class TaggingTask(base_task.Task):
     if self._hub_module:
       encoder_network = utils.get_encoder_from_hub(self._hub_module)
     else:
-      encoder_network = encoders.instantiate_encoder_from_cfg(
-          self.task_config.model.encoder)
+      encoder_network = encoders.build_encoder(self.task_config.model.encoder)
     return models.BertTokenClassifier(
         network=encoder_network,
...
@@ -53,8 +53,8 @@ class TaggingTest(tf.test.TestCase):
   def setUp(self):
     super(TaggingTest, self).setUp()
-    self._encoder_config = encoders.TransformerEncoderConfig(
-        vocab_size=30522, num_layers=1)
+    self._encoder_config = encoders.EncoderConfig(
+        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))
     self._train_data_config = tagging_data_loader.TaggingDataConfig(
         input_path="dummy", seq_length=128, global_batch_size=1)
@@ -74,7 +74,7 @@ class TaggingTest(tf.test.TestCase):
   def test_task(self):
     # Saves a checkpoint.
-    encoder = encoders.instantiate_encoder_from_cfg(self._encoder_config)
+    encoder = encoders.build_encoder(self._encoder_config)
     ckpt = tf.train.Checkpoint(encoder=encoder)
     saved_path = ckpt.save(self.get_temp_dir())
...