Commit 43587c64 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 316784919
parent 85cfe94d
@@ -29,9 +29,9 @@ from official.nlp.modeling import losses as loss_lib
 @dataclasses.dataclass
 class SentencePredictionConfig(cfg.TaskConfig):
   """The model config."""
-  # At most one of `pretrain_checkpoint_dir` and `hub_module_url` can
+  # At most one of `init_checkpoint` and `hub_module_url` can
   # be specified.
-  pretrain_checkpoint_dir: str = ''
+  init_checkpoint: str = ''
   hub_module_url: str = ''
   network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
       num_masked_tokens=0,
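The rename from `pretrain_checkpoint_dir` to `init_checkpoint` widens the field's meaning: it can now point at either a checkpoint directory or a specific checkpoint prefix, since `initialize` below resolves directories with `tf.train.latest_checkpoint`. A minimal sketch of setting the field either way (the import path and the checkpoint paths are illustrative assumptions, not taken from this diff):

```python
# Sketch only: import path and checkpoint paths are assumptions.
from official.nlp.tasks import sentence_prediction

# A directory: the newest checkpoint inside it is resolved at initialize().
config = sentence_prediction.SentencePredictionConfig(
    init_checkpoint='/tmp/pretrain_ckpts')

# Or a specific checkpoint prefix, used as-is.
config = sentence_prediction.SentencePredictionConfig(
    init_checkpoint='/tmp/pretrain_ckpts/ckpt-10')
```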
@@ -52,7 +52,7 @@ class SentencePredictionTask(base_task.Task):

   def __init__(self, params=cfg.TaskConfig):
     super(SentencePredictionTask, self).__init__(params)
-    if params.hub_module_url and params.pretrain_checkpoint_dir:
+    if params.hub_module_url and params.init_checkpoint:
       raise ValueError('At most one of `hub_module_url` and '
                        '`pretrain_checkpoint_dir` can be specified.')
     if params.hub_module_url:
@@ -82,8 +82,8 @@ class SentencePredictionTask(base_task.Task):

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
     loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
         labels=labels,
-        predictions=tf.nn.log_softmax(model_outputs['sentence_prediction'],
-                                      axis=-1))
+        predictions=tf.nn.log_softmax(
+            model_outputs['sentence_prediction'], axis=-1))
     if aux_losses:
       loss += tf.add_n(aux_losses)
@@ -92,6 +92,7 @@ class SentencePredictionTask(base_task.Task):
   def build_inputs(self, params, input_context=None):
     """Returns tf.data.Dataset for sentence_prediction task."""
     if params.input_path == 'dummy':
+
       def dummy_data(_):
         dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
         x = dict(
@@ -112,9 +113,7 @@ class SentencePredictionTask(base_task.Task):

   def build_metrics(self, training=None):
     del training
-    metrics = [
-        tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')
-    ]
+    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
     return metrics

   def process_metrics(self, metrics, labels, model_outputs):
@@ -126,8 +125,10 @@ class SentencePredictionTask(base_task.Task):

   def initialize(self, model):
     """Load a pretrained checkpoint (if exists) and then train from iter 0."""
-    pretrain_ckpt_dir = self.task_config.pretrain_checkpoint_dir
-    if not pretrain_ckpt_dir:
+    ckpt_dir_or_file = self.task_config.init_checkpoint
+    if tf.io.gfile.isdir(ckpt_dir_or_file):
+      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
+    if not ckpt_dir_or_file:
       return

     pretrain2finetune_mapping = {
@@ -137,10 +138,7 @@ class SentencePredictionTask(base_task.Task):
             model.checkpoint_items['sentence_prediction.pooler_dense'],
     }
     ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
-    latest_pretrain_ckpt = tf.train.latest_checkpoint(pretrain_ckpt_dir)
-    if latest_pretrain_ckpt is None:
-      raise FileNotFoundError(
-          'Cannot find pretrain checkpoint under {}'.format(pretrain_ckpt_dir))
-    status = ckpt.restore(latest_pretrain_ckpt)
+    status = ckpt.restore(ckpt_dir_or_file)
     status.expect_partial().assert_existing_objects_matched()
-    logging.info('finished loading pretrained checkpoint.')
+    logging.info('finished loading pretrained checkpoint from %s',
+                 ckpt_dir_or_file)
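The reworked `initialize` accepts both forms transparently: `tf.io.gfile.isdir` detects a directory, `tf.train.latest_checkpoint` maps it to its newest checkpoint prefix, and a bare prefix passes through unchanged before `ckpt.restore` runs. A standalone sketch of just that resolution step, using only stock TensorFlow APIs (the model mapping above is elided):

```python
import tensorflow as tf


def resolve_ckpt(ckpt_dir_or_file):
  """Mirrors the directory-or-file handling in `initialize` above."""
  if tf.io.gfile.isdir(ckpt_dir_or_file):
    # A directory: return the newest checkpoint prefix recorded in its
    # `checkpoint` state file, or None if there is none yet.
    ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file  # None means "nothing to restore", as above.
```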
@@ -43,8 +43,10 @@ class SentencePredictionTaskTest(tf.test.TestCase):

   def test_task(self):
     config = sentence_prediction.SentencePredictionConfig(
+        init_checkpoint=self.get_temp_dir(),
         network=bert.BertPretrainerConfig(
-            encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
+            encoder=encoders.TransformerEncoderConfig(
+                vocab_size=30522, num_layers=1),
             num_masked_tokens=0,
             cls_heads=[
                 bert.ClsHeadConfig(
@@ -62,6 +64,21 @@ class SentencePredictionTaskTest(tf.test.TestCase):
     task.train_step(next(iterator), model, optimizer, metrics=metrics)
     task.validation_step(next(iterator), model, metrics=metrics)

+    # Saves a checkpoint.
+    pretrain_cfg = bert.BertPretrainerConfig(
+        encoder=encoders.TransformerEncoderConfig(
+            vocab_size=30522, num_layers=1),
+        num_masked_tokens=20,
+        cls_heads=[
+            bert.ClsHeadConfig(
+                inner_dim=10, num_classes=3, name="next_sentence")
+        ])
+    pretrain_model = bert.instantiate_from_cfg(pretrain_cfg)
+    ckpt = tf.train.Checkpoint(
+        model=pretrain_model, **pretrain_model.checkpoint_items)
+    ckpt.save(config.init_checkpoint)
+    task.initialize(model)
+
   def _export_bert_tfhub(self):
     bert_config = configs.BertConfig(
         vocab_size=30522,
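The new test block saves a pretraining checkpoint into `config.init_checkpoint` (a temp directory) and then calls `task.initialize(model)`, exercising the directory-resolution path end to end. Below is a minimal save/restore round trip with the same `tf.train.Checkpoint` API, using a toy module in place of the BERT pretrainer (all names and paths here are illustrative):

```python
import tensorflow as tf

# Toy stand-in for the pretrained model; illustrative only.
model = tf.Module()
model.w = tf.Variable(2.0)

# `save` appends a step suffix, e.g. '/tmp/demo/ckpt-1'.
ckpt = tf.train.Checkpoint(model=model)
save_path = ckpt.save('/tmp/demo/ckpt')

# Restore takes the returned prefix; a directory would first be resolved
# with tf.train.latest_checkpoint, as `initialize` does.
status = tf.train.Checkpoint(model=model).restore(save_path)
status.assert_existing_objects_matched()
```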