Commit db39ef82 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 318755856
parent 997eaa19
@@ -126,10 +126,17 @@ class QADataConfig(cfg.DataConfig):
 class QADevDataConfig(cfg.DataConfig):
   """Dev data config for question answering (tasks/question_answering)."""
   input_path: str = ""
+  input_preprocessed_data_path: str = ""
+  version_2_with_negative: bool = False
+  doc_stride: int = 128
   global_batch_size: int = 48
   is_training: bool = False
   seq_length: int = 384
+  query_length: int = 64
   drop_remainder: bool = False
+  vocab_file: str = ""
+  tokenization: str = "WordPiece"  # WordPiece or SentencePiece
+  do_lower_case: bool = True


 @dataclasses.dataclass
...
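For context, the new QADevDataConfig fields fit together roughly as follows; a minimal sketch with placeholder paths (not part of this change), showing a SQuAD v2.0 eval setup:

```python
# Sketch only: illustrative QADevDataConfig for SQuAD v2.0 evaluation.
# Paths are placeholders; field names and defaults come from the hunk above.
validation_data = QADevDataConfig(
    input_path='/data/squad/dev-v2.0.json',          # raw SQuAD json
    input_preprocessed_data_path='/tmp/squad_eval',  # where eval.tf_record is written
    version_2_with_negative=True,  # v2.0 dev sets contain unanswerable questions
    doc_stride=128,                # overlap between sliding document windows
    seq_length=384,
    query_length=64,
    vocab_file='/data/squad/vocab.txt',
    tokenization='WordPiece',      # or 'SentencePiece'
    do_lower_case=True)
```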
@@ -14,7 +14,10 @@
 # limitations under the License.
 # ==============================================================================
 """Question answering task."""
-import logging
+import collections
+import json
+import os
+from absl import logging
 import dataclasses
 import tensorflow as tf
 import tensorflow_hub as hub
@@ -22,7 +25,12 @@ import tensorflow_hub as hub
 from official.core import base_task
 from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.bert import input_pipeline
+from official.nlp.bert import squad_evaluate_v1_1
+from official.nlp.bert import squad_evaluate_v2_0
+from official.nlp.bert import tokenization
 from official.nlp.configs import encoders
+from official.nlp.data import squad_lib as squad_lib_wp
+from official.nlp.data import squad_lib_sp
 from official.nlp.modeling import models
 from official.nlp.tasks import utils
@@ -33,6 +41,9 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
   # At most one of `init_checkpoint` and `hub_module_url` can be specified.
   init_checkpoint: str = ''
   hub_module_url: str = ''
+  n_best_size: int = 20
+  max_answer_length: int = 30
+  null_score_diff_threshold: float = 0.0
   model: encoders.TransformerEncoderConfig = (
       encoders.TransformerEncoderConfig())
   train_data: cfg.DataConfig = cfg.DataConfig()
@@ -41,10 +52,7 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
 @base_task.register_task_cls(QuestionAnsweringConfig)
 class QuestionAnsweringTask(base_task.Task):
-  """Task object for question answering.
-
-  TODO(lehou): Add post-processing.
-  """
+  """Task object for question answering."""

   def __init__(self, params=cfg.TaskConfig):
     super(QuestionAnsweringTask, self).__init__(params)
@@ -56,6 +64,14 @@ class QuestionAnsweringTask(base_task.Task):
     else:
       self._hub_module = None

+    if params.validation_data.tokenization == 'WordPiece':
+      self.squad_lib = squad_lib_wp
+    elif params.validation_data.tokenization == 'SentencePiece':
+      self.squad_lib = squad_lib_sp
+    else:
+      raise ValueError('Unsupported tokenization method: {}'.format(
+          params.validation_data.tokenization))
+
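A side note on the dispatch above: the same selection could be table-driven. A minimal sketch under the same imports (squad_lib_wp, squad_lib_sp), purely illustrative:

```python
# Sketch only: dict-based alternative to the if/elif chain above.
_SQUAD_LIBS = {
    'WordPiece': squad_lib_wp,
    'SentencePiece': squad_lib_sp,
}

def _get_squad_lib(tokenization_name):
  """Looks up the squad_lib module for a tokenization name, or raises."""
  if tokenization_name not in _SQUAD_LIBS:
    raise ValueError(
        'Unsupported tokenization method: {}'.format(tokenization_name))
  return _SQUAD_LIBS[tokenization_name]
```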
   def build_model(self):
     if self._hub_module:
       encoder_network = utils.get_encoder_from_hub(self._hub_module)
@@ -85,9 +101,53 @@ class QuestionAnsweringTask(base_task.Task):
     loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
     return loss

+  def _preprocess_eval_data(self, params):
+    eval_examples = self.squad_lib.read_squad_examples(
+        input_file=params.input_path,
+        is_training=False,
+        version_2_with_negative=params.version_2_with_negative)
+
+    temp_file_path = params.input_preprocessed_data_path or '/tmp'
+    eval_writer = self.squad_lib.FeatureWriter(
+        filename=os.path.join(temp_file_path, 'eval.tf_record'),
+        is_training=False)
+    eval_features = []
+
+    def _append_feature(feature, is_padding):
+      if not is_padding:
+        eval_features.append(feature)
+      eval_writer.process_feature(feature)
+
+    kwargs = dict(
+        examples=eval_examples,
+        tokenizer=tokenization.FullTokenizer(
+            vocab_file=params.vocab_file,
+            do_lower_case=params.do_lower_case),
+        max_seq_length=params.seq_length,
+        doc_stride=params.doc_stride,
+        max_query_length=params.query_length,
+        is_training=False,
+        output_fn=_append_feature,
+        batch_size=params.global_batch_size)
+    if params.tokenization == 'SentencePiece':
+      # squad_lib_sp requires one more argument 'do_lower_case'.
+      kwargs['do_lower_case'] = params.do_lower_case
+    eval_dataset_size = self.squad_lib.convert_examples_to_features(**kwargs)
+    eval_writer.close()
+
+    logging.info('***** Evaluation input stats *****')
+    logging.info('  Num orig examples = %d', len(eval_examples))
+    logging.info('  Num split examples = %d', len(eval_features))
+    logging.info('  Batch size = %d', params.global_batch_size)
+    logging.info('  Dataset size = %d', eval_dataset_size)
+
+    return eval_writer.filename, eval_examples, eval_features
+
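To see how this preprocessing is reached, a minimal sketch of the calling flow; QADevDataConfig and the task/config classes are from this change, while train_data_config and the paths are hypothetical placeholders:

```python
# Sketch only: wiring a QADevDataConfig into eval preprocessing.
# train_data_config is a hypothetical placeholder for a QADataConfig.
validation_data = bert.QADevDataConfig(
    input_path='/path/to/dev-v1.1.json',
    input_preprocessed_data_path='/tmp/squad_eval',
    vocab_file='/path/to/vocab.txt',
    tokenization='WordPiece',
    do_lower_case=True)

task = question_answering.QuestionAnsweringTask(
    question_answering.QuestionAnsweringConfig(
        model=encoders.TransformerEncoderConfig(),
        train_data=train_data_config,
        validation_data=validation_data))

# Building eval inputs triggers _preprocess_eval_data, which writes
# eval.tf_record and caches examples/features on the task for scoring.
eval_dataset = task.build_inputs(validation_data)
```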
   def build_inputs(self, params, input_context=None):
     """Returns tf.data.Dataset for the question answering task."""
     if params.input_path == 'dummy':
+      # Dummy training data for unit test.
       def dummy_data(_):
         dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
         x = dict(
@@ -105,11 +165,17 @@ class QuestionAnsweringTask(base_task.Task):
           dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
       return dataset

+    if params.is_training:
+      input_path = params.input_path
+    else:
+      input_path, self._eval_examples, self._eval_features = (
+          self._preprocess_eval_data(params))
+
     batch_size = input_context.get_per_replica_batch_size(
         params.global_batch_size) if input_context else params.global_batch_size
     # TODO(chendouble): add and use nlp.data.question_answering_dataloader.
     dataset = input_pipeline.create_squad_dataset(
-        params.input_path,
+        input_path,
         params.seq_length,
         batch_size,
         is_training=params.is_training,
@@ -141,6 +207,70 @@ class QuestionAnsweringTask(base_task.Task):
         y_true=labels,  # labels has keys 'start_positions' and 'end_positions'.
         y_pred={'start_positions': start_logits, 'end_positions': end_logits})

+  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
+    features, _ = inputs
+    unique_ids = features.pop('unique_ids')
+    model_outputs = self.inference_step(features, model)
+    start_logits, end_logits = model_outputs
+    logs = {
+        self.loss: 0.0,  # TODO(lehou): compute the real validation loss.
+        'unique_ids': unique_ids,
+        'start_logits': start_logits,
+        'end_logits': end_logits,
+    }
+    return logs
+
+  raw_aggregated_result = collections.namedtuple(
+      'RawResult', ['unique_id', 'start_logits', 'end_logits'])
+
+  def aggregate_logs(self, state=None, step_outputs=None):
+    assert step_outputs is not None, 'Got no logs from self.validation_step.'
+    if state is None:
+      state = []
+
+    for unique_ids, start_logits, end_logits in zip(
+        step_outputs['unique_ids'],
+        step_outputs['start_logits'],
+        step_outputs['end_logits']):
+      u_ids, s_logits, e_logits = (
+          unique_ids.numpy(), start_logits.numpy(), end_logits.numpy())
+      if u_ids.size == 1:
+        u_ids = [u_ids]
+        s_logits = [s_logits]
+        e_logits = [e_logits]
+      for values in zip(u_ids, s_logits, e_logits):
+        state.append(self.raw_aggregated_result(
+            unique_id=values[0],
+            start_logits=values[1].tolist(),
+            end_logits=values[2].tolist()))
+    return state
+
+  def reduce_aggregated_logs(self, aggregated_logs):
+    all_predictions, _, scores_diff = (
+        self.squad_lib.postprocess_output(
+            self._eval_examples,
+            self._eval_features,
+            aggregated_logs,
+            self.task_config.n_best_size,
+            self.task_config.max_answer_length,
+            self.task_config.validation_data.do_lower_case,
+            version_2_with_negative=(
+                self.task_config.validation_data.version_2_with_negative),
+            null_score_diff_threshold=(
+                self.task_config.null_score_diff_threshold),
+            verbose=False))
+
+    with tf.io.gfile.GFile(
+        self.task_config.validation_data.input_path, 'r') as reader:
+      dataset_json = json.load(reader)
+      pred_dataset = dataset_json['data']
+    if self.task_config.validation_data.version_2_with_negative:
+      eval_metrics = squad_evaluate_v2_0.evaluate(
+          pred_dataset, all_predictions, scores_diff)
+    else:
+      eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
+    return eval_metrics
+
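Taken together, the three hooks above form the eval pipeline. A minimal sketch of an outer loop that might drive them; the loop itself is illustrative, not part of this change, and task, model, and eval_dataset are assumed built as in the test file below:

```python
# Sketch only: validation_step -> aggregate_logs -> reduce_aggregated_logs.
def run_evaluation(task, model, eval_dataset):
  state = None
  for inputs in eval_dataset:
    # One forward pass per batch; logs carry unique_ids and start/end logits.
    logs = task.validation_step(inputs, model)
    # Accumulate per-example RawResult tuples across batches.
    state = task.aggregate_logs(state=state, step_outputs=logs)
  # Decode answer spans and score against the raw SQuAD json,
  # e.g. returning a dict containing 'final_f1'.
  return task.reduce_aggregated_logs(state)
```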
   def initialize(self, model):
     """Load a pretrained checkpoint (if exists) and then train from iter 0."""
     ckpt_dir_or_file = self.task_config.init_checkpoint
...
@@ -14,8 +14,10 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for official.nlp.tasks.question_answering."""
-import functools
+import itertools
+import json
 import os
+from absl.testing import parameterized
 import tensorflow as tf

 from official.nlp.bert import configs
@@ -25,30 +27,67 @@ from official.nlp.configs import encoders
 from official.nlp.tasks import question_answering
-class QuestionAnsweringTaskTest(tf.test.TestCase):
+class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):

   def setUp(self):
     super(QuestionAnsweringTaskTest, self).setUp()
     self._encoder_config = encoders.TransformerEncoderConfig(
         vocab_size=30522, num_layers=1)
     self._train_data_config = bert.QADataConfig(
-        input_path="dummy", seq_length=128, global_batch_size=1)
+        input_path="dummy",
+        seq_length=128,
+        global_batch_size=1)
+
+    val_data = {"version": "1.1",
+                "data": [{"paragraphs": [
+                    {"context": "Sky is blue.",
+                     "qas": [{"question": "What is blue?", "id": "1234",
+                              "answers": [{"text": "Sky", "answer_start": 0},
+                                          {"text": "Sky", "answer_start": 0},
+                                          {"text": "Sky", "answer_start": 0}]
+                              }]}]}]}
+    self._val_input_path = os.path.join(self.get_temp_dir(), "val_data.json")
+    with tf.io.gfile.GFile(self._val_input_path, "w") as writer:
+      writer.write(json.dumps(val_data, indent=4) + "\n")
+
+    self._test_vocab = os.path.join(self.get_temp_dir(), "vocab.txt")
+    with tf.io.gfile.GFile(self._test_vocab, "w") as writer:
+      writer.write("[PAD]\n[UNK]\n[CLS]\n[SEP]\n[MASK]\nsky\nis\nblue\n")
+
+  def _get_validation_data_config(self, version_2_with_negative=False):
+    return bert.QADevDataConfig(
+        input_path=self._val_input_path,
+        input_preprocessed_data_path=self.get_temp_dir(),
+        seq_length=128,
+        global_batch_size=1,
+        version_2_with_negative=version_2_with_negative,
+        vocab_file=self._test_vocab,
+        tokenization="WordPiece",
+        do_lower_case=True)
+
   def _run_task(self, config):
     task = question_answering.QuestionAnsweringTask(config)
     model = task.build_model()
     metrics = task.build_metrics()
+    task.initialize(model)

-    strategy = tf.distribute.get_strategy()
-    dataset = strategy.experimental_distribute_datasets_from_function(
-        functools.partial(task.build_inputs, config.train_data))
-    iterator = iter(dataset)
+    train_dataset = task.build_inputs(config.train_data)
+    train_iterator = iter(train_dataset)
     optimizer = tf.keras.optimizers.SGD(lr=0.1)
-    task.train_step(next(iterator), model, optimizer, metrics=metrics)
-    task.validation_step(next(iterator), model, metrics=metrics)
+    task.train_step(next(train_iterator), model, optimizer, metrics=metrics)

-  def test_task(self):
+    val_dataset = task.build_inputs(config.validation_data)
+    val_iterator = iter(val_dataset)
+    logs = task.validation_step(next(val_iterator), model, metrics=metrics)
+    logs = task.aggregate_logs(step_outputs=logs)
+    metrics = task.reduce_aggregated_logs(logs)
+    self.assertIn("final_f1", metrics)

+  @parameterized.parameters(itertools.product(
+      (False, True),
+      ("WordPiece", "SentencePiece"),
+  ))
+  def test_task(self, version_2_with_negative, tokenization):
     # Saves a checkpoint.
     pretrain_cfg = bert.BertPretrainerConfig(
         encoder=self._encoder_config,
@@ -65,22 +104,16 @@ class QuestionAnsweringTaskTest(tf.test.TestCase):
     config = question_answering.QuestionAnsweringConfig(
         init_checkpoint=saved_path,
         model=self._encoder_config,
-        train_data=self._train_data_config)
-    task = question_answering.QuestionAnsweringTask(config)
-    model = task.build_model()
-    metrics = task.build_metrics()
-    dataset = task.build_inputs(config.train_data)
-    iterator = iter(dataset)
-    optimizer = tf.keras.optimizers.SGD(lr=0.1)
-    task.train_step(next(iterator), model, optimizer, metrics=metrics)
-    task.validation_step(next(iterator), model, metrics=metrics)
-    task.initialize(model)
+        train_data=self._train_data_config,
+        validation_data=self._get_validation_data_config(
+            version_2_with_negative))
+    self._run_task(config)
   def test_task_with_fit(self):
     config = question_answering.QuestionAnsweringConfig(
         model=self._encoder_config,
-        train_data=self._train_data_config)
+        train_data=self._train_data_config,
+        validation_data=self._get_validation_data_config())
     task = question_answering.QuestionAnsweringTask(config)
     model = task.build_model()
     model = task.compile_model(
@@ -122,7 +155,8 @@ class QuestionAnsweringTaskTest(tf.test.TestCase):
     config = question_answering.QuestionAnsweringConfig(
         hub_module_url=hub_module_url,
         model=self._encoder_config,
-        train_data=self._train_data_config)
+        train_data=self._train_data_config,
+        validation_data=self._get_validation_data_config())
     self._run_task(config)
...