Commit cc783998 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 327325736
parent bf6a29e4
...@@ -97,7 +97,11 @@ class BertClassifier(tf.keras.Model): ...@@ -97,7 +97,11 @@ class BertClassifier(tf.keras.Model):
@property @property
def checkpoint_items(self): def checkpoint_items(self):
return dict(encoder=self._network) items = dict(encoder=self._network)
if hasattr(self.classifier, 'checkpoint_items'):
for key, item in self.classifier.checkpoint_items.items():
items['.'.join([self.classifier.name, key])] = item
return items
def get_config(self): def get_config(self):
return self._config return self._config
......
...@@ -215,9 +215,8 @@ class SentencePredictionTask(base_task.Task): ...@@ -215,9 +215,8 @@ class SentencePredictionTask(base_task.Task):
pretrain2finetune_mapping = { pretrain2finetune_mapping = {
'encoder': model.checkpoint_items['encoder'], 'encoder': model.checkpoint_items['encoder'],
} }
# TODO(b/160251903): Investigate why no pooler dense improves finetuning
# accuracies.
if self.task_config.init_cls_pooler: if self.task_config.init_cls_pooler:
# This option is valid when use_encoder_pooler is false.
pretrain2finetune_mapping[ pretrain2finetune_mapping[
'next_sentence.pooler_dense'] = model.checkpoint_items[ 'next_sentence.pooler_dense'] = model.checkpoint_items[
'sentence_prediction.pooler_dense'] 'sentence_prediction.pooler_dense']
......
...@@ -87,34 +87,40 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -87,34 +87,40 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
task.train_step(next(iterator), model, optimizer, metrics=metrics) task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics) task.validation_step(next(iterator), model, metrics=metrics)
def test_task(self): @parameterized.named_parameters(
config = sentence_prediction.SentencePredictionConfig( ("init_cls_pooler", True),
init_checkpoint=self.get_temp_dir(), ("init_encoder", False),
model=self.get_model_config(2), )
train_data=self._train_data_config) def test_task(self, init_cls_pooler):
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
metrics = task.build_metrics()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(lr=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
# Saves a checkpoint. # Saves a checkpoint.
pretrain_cfg = bert.PretrainerConfig( pretrain_cfg = bert.PretrainerConfig(
encoder=encoders.EncoderConfig( encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)), bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
cls_heads=[ cls_heads=[
bert.ClsHeadConfig( bert.ClsHeadConfig(
inner_dim=10, num_classes=3, name="next_sentence") inner_dim=768, num_classes=2, name="next_sentence")
]) ])
pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg) pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
ckpt = tf.train.Checkpoint( ckpt = tf.train.Checkpoint(
model=pretrain_model, **pretrain_model.checkpoint_items) model=pretrain_model, **pretrain_model.checkpoint_items)
ckpt.save(config.init_checkpoint) init_path = ckpt.save(self.get_temp_dir())
# Creates the task.
config = sentence_prediction.SentencePredictionConfig(
init_checkpoint=init_path,
model=self.get_model_config(num_classes=2),
train_data=self._train_data_config,
init_cls_pooler=init_cls_pooler)
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
metrics = task.build_metrics()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(lr=0.1)
task.initialize(model) task.initialize(model)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
@parameterized.named_parameters( @parameterized.named_parameters(
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment