Commit 29d45e88 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 325073831
parent 36748648
......@@ -20,13 +20,9 @@ Includes configurations and instantiation methods.
from typing import List, Optional, Text
import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import bert_pretrainer
@dataclasses.dataclass
......@@ -40,32 +36,9 @@ class ClsHeadConfig(base_config.Config):
@dataclasses.dataclass
class BertPretrainerConfig(base_config.Config):
"""BERT encoder configuration."""
encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
class PretrainerConfig(base_config.Config):
"""Pretrainer configuration."""
encoder: encoders.EncoderConfig = encoders.EncoderConfig()
cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
def instantiate_classification_heads_from_cfgs(
cls_head_configs: List[ClsHeadConfig]) -> List[layers.ClassificationHead]:
return [
layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
] if cls_head_configs else []
def instantiate_pretrainer_from_cfg(
config: BertPretrainerConfig,
encoder_network: Optional[tf.keras.Model] = None
) -> bert_pretrainer.BertPretrainerV2:
"""Instantiates a BertPretrainer from the config."""
encoder_cfg = config.encoder
if encoder_network is None:
encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
return bert_pretrainer.BertPretrainerV2(
mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
encoder_network=encoder_network,
classification_heads=instantiate_classification_heads_from_cfgs(
config.cls_heads))
mlm_activation: str = "gelu"
mlm_initializer_range: float = 0.02
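For reference, a minimal usage sketch of the renamed PretrainerConfig using the fields shown in this diff (the values are illustrative, not taken from the commit):

from official.nlp.configs import bert
from official.nlp.configs import encoders

# Illustrative values; field names come from the PretrainerConfig diff above.
pretrainer_cfg = bert.PretrainerConfig(
    encoder=encoders.EncoderConfig(
        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=12)),
    cls_heads=[
        bert.ClsHeadConfig(inner_dim=768, num_classes=2, name="next_sentence")
    ],
    mlm_activation="gelu",          # new field, replaces encoder.hidden_activation
    mlm_initializer_range=0.02)     # new field, replaces encoder.initializer_range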
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BERT configurations and models instantiation."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders
class BertModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
_ = bert.instantiate_pretrainer_from_cfg(config)
# Invokes with classification heads.
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_pretrainer_from_cfg(config)
with self.assertRaises(ValueError):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence"),
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_pretrainer_from_cfg(config)
def test_checkpoint_items(self):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
encoder = bert.instantiate_pretrainer_from_cfg(config)
self.assertSameElements(
encoder.checkpoint_items.keys(),
["encoder", "masked_lm", "next_sentence.pooler_dense"])
if __name__ == "__main__":
tf.test.main()
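For context, a hedged sketch of how a downstream task might restore only the encoder from a checkpoint saved with these checkpoint_items keys (the path and restore pattern are assumptions, not part of this commit):

import tensorflow as tf
from official.nlp.configs import encoders

# Build an encoder with the same shape as the pretrained one.
encoder = encoders.build_encoder(
    encoders.EncoderConfig(bert=encoders.BertEncoderConfig(vocab_size=10, num_layers=1)))
# "encoder" matches the key exposed by checkpoint_items in the test above.
tf.train.Checkpoint(encoder=encoder).restore(
    tf.train.latest_checkpoint("/tmp/pretrain_ckpt"))  # hypothetical directory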
......@@ -14,21 +14,17 @@
# limitations under the License.
# ==============================================================================
"""ELECTRA model configurations and instantiation methods."""
from typing import List, Optional
from typing import List
import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import electra_pretrainer
@dataclasses.dataclass
class ELECTRAPretrainerConfig(base_config.Config):
class ElectraPretrainerConfig(base_config.Config):
"""ELECTRA pretrainer configuration."""
num_masked_tokens: int = 76
sequence_length: int = 512
......@@ -36,56 +32,6 @@ class ELECTRAPretrainerConfig(base_config.Config):
discriminator_loss_weight: float = 50.0
tie_embeddings: bool = True
disallow_correct: bool = False
generator_encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
discriminator_encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
generator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
discriminator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)
def instantiate_classification_heads_from_cfgs(
cls_head_configs: List[bert.ClsHeadConfig]
) -> List[layers.ClassificationHead]:
if cls_head_configs:
return [
layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
]
else:
return []
def instantiate_pretrainer_from_cfg(
config: ELECTRAPretrainerConfig,
generator_network: Optional[tf.keras.Model] = None,
discriminator_network: Optional[tf.keras.Model] = None,
) -> electra_pretrainer.ElectraPretrainer:
"""Instantiates ElectraPretrainer from the config."""
generator_encoder_cfg = config.generator_encoder
discriminator_encoder_cfg = config.discriminator_encoder
# Copy discriminator's embeddings to generator for easier model serialization.
if discriminator_network is None:
discriminator_network = encoders.instantiate_encoder_from_cfg(
discriminator_encoder_cfg)
if generator_network is None:
if config.tie_embeddings:
embedding_layer = discriminator_network.get_embedding_layer()
generator_network = encoders.instantiate_encoder_from_cfg(
generator_encoder_cfg, embedding_layer=embedding_layer)
else:
generator_network = encoders.instantiate_encoder_from_cfg(
generator_encoder_cfg)
return electra_pretrainer.ElectraPretrainer(
generator_network=generator_network,
discriminator_network=discriminator_network,
vocab_size=config.generator_encoder.vocab_size,
num_classes=config.num_classes,
sequence_length=config.sequence_length,
num_token_predictions=config.num_masked_tokens,
mlm_activation=tf_utils.get_activation(
generator_encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=generator_encoder_cfg.initializer_range),
classification_heads=instantiate_classification_heads_from_cfgs(
config.cls_heads),
disallow_correct=config.disallow_correct)
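A comparable sketch for the renamed ElectraPretrainerConfig, assuming the fields shown above (defaults mirror the dataclass; layer counts are illustrative):

from official.nlp.configs import electra
from official.nlp.configs import encoders

# The generator and discriminator encoders now take the one-of EncoderConfig.
electra_cfg = electra.ElectraPretrainerConfig(
    num_masked_tokens=76,
    sequence_length=512,
    generator_encoder=encoders.EncoderConfig(
        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=3)),
    discriminator_encoder=encoders.EncoderConfig(
        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=12)))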
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA configurations and models instantiation."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
class ELECTRAModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
)
_ = electra.instantiate_pretrainer_from_cfg(config)
# Invokes with classification heads.
config = electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=10, num_layers=2),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = electra.instantiate_pretrainer_from_cfg(config)
if __name__ == "__main__":
tf.test.main()
......@@ -15,20 +15,23 @@
# ==============================================================================
"""Transformer Encoders.
Includes configurations and instantiation methods.
Includes configurations and factory methods.
"""
from typing import Optional
from absl import logging
import dataclasses
import gin
import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.modeling import layers
from official.nlp.modeling import networks
@dataclasses.dataclass
class TransformerEncoderConfig(base_config.Config):
class BertEncoderConfig(hyperparams.Config):
"""BERT encoder configuration."""
vocab_size: int = 30522
hidden_size: int = 768
......@@ -44,55 +47,86 @@ class TransformerEncoderConfig(base_config.Config):
embedding_size: Optional[int] = None
def instantiate_encoder_from_cfg(
config: TransformerEncoderConfig,
encoder_cls=networks.TransformerEncoder,
embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
"""Instantiate a Transformer encoder network from TransformerEncoderConfig."""
@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
"""Encoder configuration."""
type: Optional[str] = "bert"
bert: BertEncoderConfig = BertEncoderConfig()
ENCODER_CLS = {
"bert": networks.TransformerEncoder,
}
@gin.configurable
def build_encoder(config: EncoderConfig,
embedding_layer: Optional[layers.OnDeviceEmbedding] = None,
encoder_cls=None,
bypass_config: bool = False):
"""Instantiate a Transformer encoder network from EncoderConfig.
Args:
config: the one-of encoder config, which provides encoder parameters of a
chosen encoder.
embedding_layer: an external embedding layer passed to the encoder.
encoder_cls: an external encoder class not included in the supported encoders,
usually set via gin.configurable.
bypass_config: whether to ignore the config instance and create the object
directly with `encoder_cls`.
Returns:
An encoder instance.
"""
encoder_type = config.type
encoder_cfg = config.get()
encoder_cls = encoder_cls or ENCODER_CLS[encoder_type]
logging.info("Encoder class: %s to build...", encoder_cls.__name__)
if bypass_config:
return encoder_cls()
if encoder_cls.__name__ == "EncoderScaffold":
embedding_cfg = dict(
vocab_size=config.vocab_size,
type_vocab_size=config.type_vocab_size,
hidden_size=config.hidden_size,
max_seq_length=config.max_position_embeddings,
vocab_size=encoder_cfg.vocab_size,
type_vocab_size=encoder_cfg.type_vocab_size,
hidden_size=encoder_cfg.hidden_size,
max_seq_length=encoder_cfg.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
dropout_rate=config.dropout_rate,
stddev=encoder_cfg.initializer_range),
dropout_rate=encoder_cfg.dropout_rate,
)
hidden_cfg = dict(
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
intermediate_activation=tf_utils.get_activation(
config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
stddev=encoder_cfg.initializer_range),
)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cfg=hidden_cfg,
num_hidden_instances=config.num_layers,
pooled_output_dim=config.hidden_size,
num_hidden_instances=encoder_cfg.num_layers,
pooled_output_dim=encoder_cfg.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range))
stddev=encoder_cfg.initializer_range))
return encoder_cls(**kwargs)
if encoder_cls.__name__ != "TransformerEncoder":
raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
encoder_network = encoder_cls(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
num_layers=config.num_layers,
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
activation=tf_utils.get_activation(config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
max_sequence_length=config.max_position_embeddings,
type_vocab_size=config.type_vocab_size,
# Uses the default BERTEncoder configuration schema to create the encoder.
# If it does not match, please add a switch branch by the encoder type.
return encoder_cls(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
embedding_width=config.embedding_size,
stddev=encoder_cfg.initializer_range),
embedding_width=encoder_cfg.embedding_size,
embedding_layer=embedding_layer)
return encoder_network
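A short sketch of the new build_encoder factory, which dispatches on the one-of config's type field via ENCODER_CLS (names taken from the diff above; values are illustrative):

from official.nlp.configs import encoders

config = encoders.EncoderConfig(
    bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=2))
# config.type defaults to "bert"; config.get() returns the BertEncoderConfig instance.
encoder_network = encoders.build_encoder(config)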
......@@ -14,7 +14,7 @@
# ==============================================================================
"""Models package definition."""
from official.nlp.modeling.models.bert_classifier import BertClassifier
from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
from official.nlp.modeling.models.bert_pretrainer import *
from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
......@@ -19,16 +19,20 @@ import tensorflow as tf
from official.core import base_task
from official.core import task_factory
from official.modeling import tf_utils
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.modeling import layers
from official.nlp.modeling import models
@dataclasses.dataclass
class ELECTRAPretrainConfig(cfg.TaskConfig):
class ElectraPretrainConfig(cfg.TaskConfig):
"""The model config."""
model: electra.ELECTRAPretrainerConfig = electra.ELECTRAPretrainerConfig(
model: electra.ElectraPretrainerConfig = electra.ElectraPretrainerConfig(
cls_heads=[
bert.ClsHeadConfig(
inner_dim=768,
......@@ -40,13 +44,44 @@ class ELECTRAPretrainConfig(cfg.TaskConfig):
validation_data: cfg.DataConfig = cfg.DataConfig()
@task_factory.register_task_cls(ELECTRAPretrainConfig)
class ELECTRAPretrainTask(base_task.Task):
def _build_pretrainer(
config: electra.ElectraPretrainerConfig) -> models.ElectraPretrainer:
"""Instantiates ElectraPretrainer from the config."""
generator_encoder_cfg = config.generator_encoder
discriminator_encoder_cfg = config.discriminator_encoder
# Copy discriminator's embeddings to generator for easier model serialization.
discriminator_network = encoders.build_encoder(discriminator_encoder_cfg)
if config.tie_embeddings:
embedding_layer = discriminator_network.get_embedding_layer()
generator_network = encoders.build_encoder(
generator_encoder_cfg, embedding_layer=embedding_layer)
else:
generator_network = encoders.build_encoder(generator_encoder_cfg)
generator_encoder_cfg = generator_encoder_cfg.get()
return models.ElectraPretrainer(
generator_network=generator_network,
discriminator_network=discriminator_network,
vocab_size=generator_encoder_cfg.vocab_size,
num_classes=config.num_classes,
sequence_length=config.sequence_length,
num_token_predictions=config.num_masked_tokens,
mlm_activation=tf_utils.get_activation(
generator_encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=generator_encoder_cfg.initializer_range),
classification_heads=[
layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads
],
disallow_correct=config.disallow_correct)
@task_factory.register_task_cls(ElectraPretrainConfig)
class ElectraPretrainTask(base_task.Task):
"""ELECTRA Pretrain Task (Masked LM + Replaced Token Detection)."""
def build_model(self):
return electra.instantiate_pretrainer_from_cfg(
self.task_config.model)
return _build_pretrainer(self.task_config.model)
def build_losses(self,
labels,
......@@ -70,9 +105,7 @@ class ELECTRAPretrainTask(base_task.Task):
sentence_outputs = tf.cast(
model_outputs['sentence_outputs'], dtype=tf.float32)
sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
sentence_labels,
sentence_outputs,
from_logits=True)
sentence_labels, sentence_outputs, from_logits=True)
metrics['next_sentence_loss'].update_state(sentence_loss)
total_loss = mlm_loss + sentence_loss
else:
......
......@@ -24,15 +24,17 @@ from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import electra_task
class ELECTRAPretrainTaskTest(tf.test.TestCase):
class ElectraPretrainTaskTest(tf.test.TestCase):
def test_task(self):
config = electra_task.ELECTRAPretrainConfig(
model=electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1),
config = electra_task.ElectraPretrainConfig(
model=electra.ElectraPretrainerConfig(
generator_encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
discriminator_encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
num_masked_tokens=20,
sequence_length=128,
cls_heads=[
......@@ -44,7 +46,7 @@ class ELECTRAPretrainTaskTest(tf.test.TestCase):
max_predictions_per_seq=20,
seq_length=128,
global_batch_size=1))
task = electra_task.ELECTRAPretrainTask(config)
task = electra_task.ElectraPretrainTask(config)
model = task.build_model()
metrics = task.build_metrics()
dataset = task.build_inputs(config.train_data)
......
......@@ -19,15 +19,19 @@ import tensorflow as tf
from official.core import base_task
from official.core import task_factory
from official.modeling import tf_utils
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
from official.nlp.modeling import layers
from official.nlp.modeling import models
@dataclasses.dataclass
class MaskedLMConfig(cfg.TaskConfig):
"""The model config."""
model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[
model: bert.PretrainerConfig = bert.PretrainerConfig(cls_heads=[
bert.ClsHeadConfig(
inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
])
......@@ -37,11 +41,21 @@ class MaskedLMConfig(cfg.TaskConfig):
@task_factory.register_task_cls(MaskedLMConfig)
class MaskedLMTask(base_task.Task):
"""Mock task object for testing."""
"""Task object for Mask language modeling."""
def build_model(self, params=None):
params = params or self.task_config.model
return bert.instantiate_pretrainer_from_cfg(params)
config = params or self.task_config.model
encoder_cfg = config.encoder
encoder_network = encoders.build_encoder(encoder_cfg)
cls_heads = [
layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads
] if config.cls_heads else []
return models.BertPretrainerV2(
mlm_activation=tf_utils.get_activation(config.mlm_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.mlm_initializer_range),
encoder_network=encoder_network,
classification_heads=cls_heads)
def build_losses(self,
labels,
......@@ -63,9 +77,8 @@ class MaskedLMTask(base_task.Task):
sentence_outputs = tf.cast(
model_outputs['next_sentence'], dtype=tf.float32)
sentence_loss = tf.reduce_mean(
tf.keras.losses.sparse_categorical_crossentropy(sentence_labels,
sentence_outputs,
from_logits=True))
tf.keras.losses.sparse_categorical_crossentropy(
sentence_labels, sentence_outputs, from_logits=True))
metrics['next_sentence_loss'].update_state(sentence_loss)
total_loss = mlm_loss + sentence_loss
else:
......
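Putting the masked-LM pieces together, a hedged sketch of building the pretrainer through the task, mirroring the test below (the constructor call is an assumption):

from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.tasks import masked_lm

config = masked_lm.MaskedLMConfig(
    model=bert.PretrainerConfig(
        encoder=encoders.EncoderConfig(
            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))))
task = masked_lm.MaskedLMTask(config)  # assumed constructor signature
model = task.build_model()             # returns a models.BertPretrainerV2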
......@@ -28,8 +28,10 @@ class MLMTaskTest(tf.test.TestCase):
def test_task(self):
config = masked_lm.MaskedLMConfig(
init_checkpoint=self.get_temp_dir(),
model=bert.BertPretrainerConfig(
encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
model=bert.PretrainerConfig(
encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
num_layers=1)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
......
......@@ -42,8 +42,7 @@ from official.nlp.tasks import utils
@dataclasses.dataclass
class ModelConfig(base_config.Config):
"""A base span labeler configuration."""
encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
encoder: encoders.EncoderConfig = encoders.EncoderConfig()
@dataclasses.dataclass
......@@ -94,13 +93,13 @@ class QuestionAnsweringTask(base_task.Task):
if self._hub_module:
encoder_network = utils.get_encoder_from_hub(self._hub_module)
else:
encoder_network = encoders.instantiate_encoder_from_cfg(
self.task_config.model.encoder)
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
encoder_cfg = self.task_config.model.encoder.get()
# Currently, we only support bert-style question answering finetuning.
return models.BertSpanLabeler(
network=encoder_network,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.task_config.model.encoder.initializer_range))
stddev=encoder_cfg.initializer_range))
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
start_positions = labels['start_positions']
......
......@@ -25,6 +25,7 @@ from official.nlp.bert import export_tfhub
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import question_answering_dataloader
from official.nlp.tasks import masked_lm
from official.nlp.tasks import question_answering
......@@ -32,21 +33,37 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(QuestionAnsweringTaskTest, self).setUp()
self._encoder_config = encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1)
self._encoder_config = encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))
self._train_data_config = question_answering_dataloader.QADataConfig(
input_path="dummy",
seq_length=128,
global_batch_size=1)
val_data = {"version": "1.1",
"data": [{"paragraphs": [
{"context": "Sky is blue.",
"qas": [{"question": "What is blue?", "id": "1234",
"answers": [{"text": "Sky", "answer_start": 0},
{"text": "Sky", "answer_start": 0},
{"text": "Sky", "answer_start": 0}]
}]}]}]}
input_path="dummy", seq_length=128, global_batch_size=1)
val_data = {
"version":
"1.1",
"data": [{
"paragraphs": [{
"context":
"Sky is blue.",
"qas": [{
"question":
"What is blue?",
"id":
"1234",
"answers": [{
"text": "Sky",
"answer_start": 0
}, {
"text": "Sky",
"answer_start": 0
}, {
"text": "Sky",
"answer_start": 0
}]
}]
}]
}]
}
self._val_input_path = os.path.join(self.get_temp_dir(), "val_data.json")
with tf.io.gfile.GFile(self._val_input_path, "w") as writer:
writer.write(json.dumps(val_data, indent=4) + "\n")
......@@ -87,19 +104,20 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
metrics = task.reduce_aggregated_logs(logs)
self.assertIn("final_f1", metrics)
@parameterized.parameters(itertools.product(
(False, True),
("WordPiece", "SentencePiece"),
))
@parameterized.parameters(
itertools.product(
(False, True),
("WordPiece", "SentencePiece"),
))
def test_task(self, version_2_with_negative, tokenization):
# Saves a checkpoint.
pretrain_cfg = bert.BertPretrainerConfig(
pretrain_cfg = bert.PretrainerConfig(
encoder=self._encoder_config,
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=3, name="next_sentence")
])
pretrain_model = bert.instantiate_pretrainer_from_cfg(pretrain_cfg)
pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
ckpt = tf.train.Checkpoint(
model=pretrain_model, **pretrain_model.checkpoint_items)
saved_path = ckpt.save(self.get_temp_dir())
......
......@@ -44,8 +44,7 @@ class ModelConfig(base_config.Config):
"""A classifier/regressor configuration."""
num_classes: int = 0
use_encoder_pooler: bool = False
encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
encoder: encoders.EncoderConfig = encoders.EncoderConfig()
@dataclasses.dataclass
......@@ -85,15 +84,14 @@ class SentencePredictionTask(base_task.Task):
if self._hub_module:
encoder_network = utils.get_encoder_from_hub(self._hub_module)
else:
encoder_network = encoders.instantiate_encoder_from_cfg(
self.task_config.model.encoder)
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
encoder_cfg = self.task_config.model.encoder.get()
# Currently, we only support bert-style sentence prediction finetuning.
return models.BertClassifier(
network=encoder_network,
num_classes=self.task_config.model.num_classes,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.task_config.model.encoder.initializer_range),
stddev=encoder_cfg.initializer_range),
use_encoder_pooler=self.task_config.model.use_encoder_pooler)
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
......
......@@ -26,6 +26,7 @@ from official.nlp.bert import export_tfhub
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.tasks import masked_lm
from official.nlp.tasks import sentence_prediction
......@@ -68,8 +69,8 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
def get_model_config(self, num_classes):
return sentence_prediction.ModelConfig(
encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1),
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
num_classes=num_classes)
def _run_task(self, config):
......@@ -102,14 +103,14 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
task.validation_step(next(iterator), model, metrics=metrics)
# Saves a checkpoint.
pretrain_cfg = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1),
pretrain_cfg = bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=3, name="next_sentence")
])
pretrain_model = bert.instantiate_pretrainer_from_cfg(pretrain_cfg)
pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
ckpt = tf.train.Checkpoint(
model=pretrain_model, **pretrain_model.checkpoint_items)
ckpt.save(config.init_checkpoint)
......@@ -136,8 +137,8 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
if num_classes == 1:
self.assertIsInstance(metrics[0], tf.keras.metrics.MeanSquaredError)
else:
self.assertIsInstance(
metrics[0], tf.keras.metrics.SparseCategoricalAccuracy)
self.assertIsInstance(metrics[0],
tf.keras.metrics.SparseCategoricalAccuracy)
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
......
......@@ -37,8 +37,7 @@ from official.nlp.tasks import utils
@dataclasses.dataclass
class ModelConfig(base_config.Config):
"""A base span labeler configuration."""
encoder: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
encoder: encoders.EncoderConfig = encoders.EncoderConfig()
head_dropout: float = 0.1
head_initializer_range: float = 0.02
......@@ -102,8 +101,7 @@ class TaggingTask(base_task.Task):
if self._hub_module:
encoder_network = utils.get_encoder_from_hub(self._hub_module)
else:
encoder_network = encoders.instantiate_encoder_from_cfg(
self.task_config.model.encoder)
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
return models.BertTokenClassifier(
network=encoder_network,
......
......@@ -53,8 +53,8 @@ class TaggingTest(tf.test.TestCase):
def setUp(self):
super(TaggingTest, self).setUp()
self._encoder_config = encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1)
self._encoder_config = encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))
self._train_data_config = tagging_data_loader.TaggingDataConfig(
input_path="dummy", seq_length=128, global_batch_size=1)
......@@ -74,7 +74,7 @@ class TaggingTest(tf.test.TestCase):
def test_task(self):
# Saves a checkpoint.
encoder = encoders.instantiate_encoder_from_cfg(self._encoder_config)
encoder = encoders.build_encoder(self._encoder_config)
ckpt = tf.train.Checkpoint(encoder=encoder)
saved_path = ckpt.save(self.get_temp_dir())
......