Commit 51a7d02b authored by Pengchong Jin's avatar Pengchong Jin Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 317962481
parent 82c91e31
...@@ -193,7 +193,7 @@ class TrainerConfig(base_config.Config): ...@@ -193,7 +193,7 @@ class TrainerConfig(base_config.Config):
@dataclasses.dataclass @dataclasses.dataclass
class TaskConfig(base_config.Config): class TaskConfig(base_config.Config):
network: base_config.Config = None model: base_config.Config = None
train_data: DataConfig = DataConfig() train_data: DataConfig = DataConfig()
validation_data: DataConfig = DataConfig() validation_data: DataConfig = DataConfig()
......
...@@ -27,7 +27,7 @@ from official.nlp.modeling import losses as loss_lib ...@@ -27,7 +27,7 @@ from official.nlp.modeling import losses as loss_lib
@dataclasses.dataclass @dataclasses.dataclass
class MaskedLMConfig(cfg.TaskConfig): class MaskedLMConfig(cfg.TaskConfig):
"""The model config.""" """The model config."""
network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[ model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[
bert.ClsHeadConfig( bert.ClsHeadConfig(
inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence') inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
]) ])
...@@ -40,7 +40,7 @@ class MaskedLMTask(base_task.Task): ...@@ -40,7 +40,7 @@ class MaskedLMTask(base_task.Task):
"""Mock task object for testing.""" """Mock task object for testing."""
def build_model(self): def build_model(self):
return bert.instantiate_bertpretrainer_from_cfg(self.task_config.network) return bert.instantiate_bertpretrainer_from_cfg(self.task_config.model)
def build_losses(self, def build_losses(self,
labels, labels,
......
...@@ -26,7 +26,7 @@ class MLMTaskTest(tf.test.TestCase): ...@@ -26,7 +26,7 @@ class MLMTaskTest(tf.test.TestCase):
def test_task(self): def test_task(self):
config = masked_lm.MaskedLMConfig( config = masked_lm.MaskedLMConfig(
network=bert.BertPretrainerConfig( model=bert.BertPretrainerConfig(
encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1), encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
num_masked_tokens=20, num_masked_tokens=20,
cls_heads=[ cls_heads=[
......
...@@ -33,7 +33,7 @@ class QuestionAnsweringConfig(cfg.TaskConfig): ...@@ -33,7 +33,7 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
# At most one of `init_checkpoint` and `hub_module_url` can be specified. # At most one of `init_checkpoint` and `hub_module_url` can be specified.
init_checkpoint: str = '' init_checkpoint: str = ''
hub_module_url: str = '' hub_module_url: str = ''
network: encoders.TransformerEncoderConfig = ( model: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig()) encoders.TransformerEncoderConfig())
train_data: cfg.DataConfig = cfg.DataConfig() train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig() validation_data: cfg.DataConfig = cfg.DataConfig()
...@@ -61,12 +61,12 @@ class QuestionAnsweringTask(base_task.Task): ...@@ -61,12 +61,12 @@ class QuestionAnsweringTask(base_task.Task):
encoder_network = utils.get_encoder_from_hub(self._hub_module) encoder_network = utils.get_encoder_from_hub(self._hub_module)
else: else:
encoder_network = encoders.instantiate_encoder_from_cfg( encoder_network = encoders.instantiate_encoder_from_cfg(
self.task_config.network) self.task_config.model)
return models.BertSpanLabeler( return models.BertSpanLabeler(
network=encoder_network, network=encoder_network,
initializer=tf.keras.initializers.TruncatedNormal( initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.task_config.network.initializer_range)) stddev=self.task_config.model.initializer_range))
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor: def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
start_positions = labels['start_positions'] start_positions = labels['start_positions']
......
...@@ -64,7 +64,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase): ...@@ -64,7 +64,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase):
config = question_answering.QuestionAnsweringConfig( config = question_answering.QuestionAnsweringConfig(
init_checkpoint=saved_path, init_checkpoint=saved_path,
network=self._encoder_config, model=self._encoder_config,
train_data=self._train_data_config) train_data=self._train_data_config)
task = question_answering.QuestionAnsweringTask(config) task = question_answering.QuestionAnsweringTask(config)
model = task.build_model() model = task.build_model()
...@@ -79,7 +79,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase): ...@@ -79,7 +79,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase):
def test_task_with_fit(self): def test_task_with_fit(self):
config = question_answering.QuestionAnsweringConfig( config = question_answering.QuestionAnsweringConfig(
network=self._encoder_config, model=self._encoder_config,
train_data=self._train_data_config) train_data=self._train_data_config)
task = question_answering.QuestionAnsweringTask(config) task = question_answering.QuestionAnsweringTask(config)
model = task.build_model() model = task.build_model()
...@@ -121,7 +121,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase): ...@@ -121,7 +121,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase):
hub_module_url = self._export_bert_tfhub() hub_module_url = self._export_bert_tfhub()
config = question_answering.QuestionAnsweringConfig( config = question_answering.QuestionAnsweringConfig(
hub_module_url=hub_module_url, hub_module_url=hub_module_url,
network=self._encoder_config, model=self._encoder_config,
train_data=self._train_data_config) train_data=self._train_data_config)
self._run_task(config) self._run_task(config)
......
...@@ -38,7 +38,7 @@ class SentencePredictionConfig(cfg.TaskConfig): ...@@ -38,7 +38,7 @@ class SentencePredictionConfig(cfg.TaskConfig):
init_checkpoint: str = '' init_checkpoint: str = ''
hub_module_url: str = '' hub_module_url: str = ''
metric_type: str = 'accuracy' metric_type: str = 'accuracy'
network: bert.BertPretrainerConfig = bert.BertPretrainerConfig( model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
num_masked_tokens=0, # No masked language modeling head. num_masked_tokens=0, # No masked language modeling head.
cls_heads=[ cls_heads=[
bert.ClsHeadConfig( bert.ClsHeadConfig(
...@@ -70,9 +70,9 @@ class SentencePredictionTask(base_task.Task): ...@@ -70,9 +70,9 @@ class SentencePredictionTask(base_task.Task):
if self._hub_module: if self._hub_module:
encoder_from_hub = utils.get_encoder_from_hub(self._hub_module) encoder_from_hub = utils.get_encoder_from_hub(self._hub_module)
return bert.instantiate_bertpretrainer_from_cfg( return bert.instantiate_bertpretrainer_from_cfg(
self.task_config.network, encoder_network=encoder_from_hub) self.task_config.model, encoder_network=encoder_from_hub)
else: else:
return bert.instantiate_bertpretrainer_from_cfg(self.task_config.network) return bert.instantiate_bertpretrainer_from_cfg(self.task_config.model)
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor: def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
loss = loss_lib.weighted_sparse_categorical_crossentropy_loss( loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
......
...@@ -34,7 +34,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -34,7 +34,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
self._train_data_config = bert.SentencePredictionDataConfig( self._train_data_config = bert.SentencePredictionDataConfig(
input_path="dummy", seq_length=128, global_batch_size=1) input_path="dummy", seq_length=128, global_batch_size=1)
def get_network_config(self, num_classes): def get_model_config(self, num_classes):
return bert.BertPretrainerConfig( return bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig( encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1), vocab_size=30522, num_layers=1),
...@@ -63,7 +63,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -63,7 +63,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
def test_task(self): def test_task(self):
config = sentence_prediction.SentencePredictionConfig( config = sentence_prediction.SentencePredictionConfig(
init_checkpoint=self.get_temp_dir(), init_checkpoint=self.get_temp_dir(),
network=self.get_network_config(2), model=self.get_model_config(2),
train_data=self._train_data_config) train_data=self._train_data_config)
task = sentence_prediction.SentencePredictionTask(config) task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model() model = task.build_model()
...@@ -96,7 +96,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -96,7 +96,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
config = sentence_prediction.SentencePredictionConfig( config = sentence_prediction.SentencePredictionConfig(
metric_type=metric_type, metric_type=metric_type,
init_checkpoint=self.get_temp_dir(), init_checkpoint=self.get_temp_dir(),
network=self.get_network_config(num_classes), model=self.get_model_config(num_classes),
train_data=self._train_data_config) train_data=self._train_data_config)
task = sentence_prediction.SentencePredictionTask(config) task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model() model = task.build_model()
...@@ -115,7 +115,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -115,7 +115,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
def test_task_with_fit(self): def test_task_with_fit(self):
config = sentence_prediction.SentencePredictionConfig( config = sentence_prediction.SentencePredictionConfig(
network=self.get_network_config(2), train_data=self._train_data_config) model=self.get_model_config(2), train_data=self._train_data_config)
task = sentence_prediction.SentencePredictionTask(config) task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model() model = task.build_model()
model = task.compile_model( model = task.compile_model(
...@@ -154,7 +154,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -154,7 +154,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
hub_module_url = self._export_bert_tfhub() hub_module_url = self._export_bert_tfhub()
config = sentence_prediction.SentencePredictionConfig( config = sentence_prediction.SentencePredictionConfig(
hub_module_url=hub_module_url, hub_module_url=hub_module_url,
network=self.get_network_config(2), model=self.get_model_config(2),
train_data=self._train_data_config) train_data=self._train_data_config)
self._run_task(config) self._run_task(config)
......
...@@ -33,7 +33,7 @@ class TaggingConfig(cfg.TaskConfig): ...@@ -33,7 +33,7 @@ class TaggingConfig(cfg.TaskConfig):
# At most one of `init_checkpoint` and `hub_module_url` can be specified. # At most one of `init_checkpoint` and `hub_module_url` can be specified.
init_checkpoint: str = '' init_checkpoint: str = ''
hub_module_url: str = '' hub_module_url: str = ''
network: encoders.TransformerEncoderConfig = ( model: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig()) encoders.TransformerEncoderConfig())
num_classes: int = 0 num_classes: int = 0
# The ignored label id will not contribute to loss. # The ignored label id will not contribute to loss.
...@@ -67,14 +67,14 @@ class TaggingTask(base_task.Task): ...@@ -67,14 +67,14 @@ class TaggingTask(base_task.Task):
encoder_network = utils.get_encoder_from_hub(self._hub_module) encoder_network = utils.get_encoder_from_hub(self._hub_module)
else: else:
encoder_network = encoders.instantiate_encoder_from_cfg( encoder_network = encoders.instantiate_encoder_from_cfg(
self.task_config.network) self.task_config.model)
return models.BertTokenClassifier( return models.BertTokenClassifier(
network=encoder_network, network=encoder_network,
num_classes=self.task_config.num_classes, num_classes=self.task_config.num_classes,
initializer=tf.keras.initializers.TruncatedNormal( initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.task_config.network.initializer_range), stddev=self.task_config.model.initializer_range),
dropout_rate=self.task_config.network.dropout_rate, dropout_rate=self.task_config.model.dropout_rate,
output='logits') output='logits')
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor: def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
......
...@@ -56,7 +56,7 @@ class TaggingTest(tf.test.TestCase): ...@@ -56,7 +56,7 @@ class TaggingTest(tf.test.TestCase):
config = tagging.TaggingConfig( config = tagging.TaggingConfig(
init_checkpoint=saved_path, init_checkpoint=saved_path,
network=self._encoder_config, model=self._encoder_config,
train_data=self._train_data_config, train_data=self._train_data_config,
num_classes=3) num_classes=3)
task = tagging.TaggingTask(config) task = tagging.TaggingTask(config)
...@@ -72,7 +72,7 @@ class TaggingTest(tf.test.TestCase): ...@@ -72,7 +72,7 @@ class TaggingTest(tf.test.TestCase):
def test_task_with_fit(self): def test_task_with_fit(self):
config = tagging.TaggingConfig( config = tagging.TaggingConfig(
network=self._encoder_config, model=self._encoder_config,
train_data=self._train_data_config, train_data=self._train_data_config,
num_classes=3) num_classes=3)
...@@ -115,7 +115,7 @@ class TaggingTest(tf.test.TestCase): ...@@ -115,7 +115,7 @@ class TaggingTest(tf.test.TestCase):
hub_module_url = self._export_bert_tfhub() hub_module_url = self._export_bert_tfhub()
config = tagging.TaggingConfig( config = tagging.TaggingConfig(
hub_module_url=hub_module_url, hub_module_url=hub_module_url,
network=self._encoder_config, model=self._encoder_config,
num_classes=4, num_classes=4,
train_data=self._train_data_config) train_data=self._train_data_config)
self._run_task(config) self._run_task(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment