Commit 0ab5dcbf authored by Philip Pham's avatar Philip Pham Committed by A. Unique TensorFlower
Browse files

Add TriviaQA Task to projects

PiperOrigin-RevId: 334950562
parent ec955c21
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TriviaQA: A Reading Comprehension Dataset."""
import functools
import json
import os
from absl import logging
import apache_beam as beam
import six
import tensorflow as tf
import tensorflow_datasets.public_api as tfds
from official.nlp.projects.triviaqa import preprocess
# BibTeX entry reported via `tfds.core.DatasetInfo.citation`.
_CITATION = """
@article{2017arXivtriviaqa,
author = {{Joshi}, Mandar and {Choi}, Eunsol and {Weld},
Daniel and {Zettlemoyer}, Luke},
title = "{triviaqa: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension}",
journal = {arXiv e-prints},
year = 2017,
eid = {arXiv:1705.03551},
pages = {arXiv:1705.03551},
archivePrefix = {arXiv},
eprint = {1705.03551},
}
"""
# Source tarball URL; formatted with "rc" or "unfiltered".
_DOWNLOAD_URL_TMPL = (
    "http://nlp.cs.washington.edu/triviaqa/data/triviaqa-{}.tar.gz")
# Glob patterns for the per-split question files inside the archives.
_TRAIN_FILE_FORMAT = "*-train.json"
_VALIDATION_FILE_FORMAT = "*-dev.json"
_TEST_FILE_FORMAT = "*test-without-answers.json"
# Evidence-document directories inside the "rc" archive.
_WEB_EVIDENCE_DIR = "evidence/web"
_WIKI_EVIDENCE_DIR = "evidence/wikipedia"
# NOTE(review): "TriviaqQA" below is a typo for "TriviaQA", but the string is
# user-visible dataset metadata, so it is deliberately left byte-identical.
_DESCRIPTION = """\
TriviaqQA is a reading comprehension dataset containing over 650K
question-answer-evidence triples. TriviaqQA includes 95K question-answer
pairs authored by trivia enthusiasts and independently gathered evidence
documents, six per question on average, that provide high quality distant
supervision for answering the questions.
"""
_RC_DESCRIPTION = """\
Question-answer pairs where all documents for a given question contain the
answer string(s).
"""
_UNFILTERED_DESCRIPTION = """\
110k question-answer pairs for open domain QA where not all documents for a
given question contain the answer string(s). This makes the unfiltered dataset
more appropriate for IR-style QA.
"""
_CONTEXT_ADDENDUM = "Includes context from Wikipedia and search results."
def _web_evidence_dir(tmp_dir):
  """Globs the extracted web-evidence directory under `tmp_dir`."""
  pattern = os.path.join(tmp_dir, _WEB_EVIDENCE_DIR)
  return tf.io.gfile.glob(pattern)
def _wiki_evidence_dir(tmp_dir):
  """Globs the extracted Wikipedia-evidence directory under `tmp_dir`."""
  pattern = os.path.join(tmp_dir, _WIKI_EVIDENCE_DIR)
  return tf.io.gfile.glob(pattern)
class TriviaQAConfig(tfds.core.BuilderConfig):
  """BuilderConfig for TriviaQA."""

  def __init__(self, *, unfiltered=False, exclude_context=False, **kwargs):
    """BuilderConfig for TriviaQA.

    Args:
      unfiltered: bool, whether to use the unfiltered version of the dataset,
        intended for open-domain QA.
      exclude_context: bool, whether to exclude Wikipedia and search context
        for reduced size.
      **kwargs: keyword arguments forwarded to super.
    """
    # Derive the config name and description from the two boolean axes.
    if unfiltered:
      name, description = "unfiltered", _UNFILTERED_DESCRIPTION
    else:
      name, description = "rc", _RC_DESCRIPTION
    if exclude_context:
      name += ".nocontext"
    else:
      description += _CONTEXT_ADDENDUM
    super(TriviaQAConfig, self).__init__(
        name=name,
        description=description,
        version=tfds.core.Version("1.1.1"),
        **kwargs)
    self.unfiltered = unfiltered
    self.exclude_context = exclude_context
class BigBirdTriviaQAConfig(tfds.core.BuilderConfig):
  """BuilderConfig for the preprocessed (BigBird) TriviaQA variant."""

  def __init__(self, **kwargs):
    """BuilderConfig for TriviaQA.

    Args:
      **kwargs: keyword arguments forwarded to super.
    """
    super(BigBirdTriviaQAConfig, self).__init__(
        name="rc_wiki.preprocessed",
        description=_RC_DESCRIPTION,
        version=tfds.core.Version("1.1.1"),
        **kwargs)
    # This variant is always the filtered ("rc"), context-included dataset.
    self.unfiltered = False
    self.exclude_context = False

  def configure(self,
                sentencepiece_model_path,
                sequence_length,
                stride,
                global_sequence_length=None):
    """Configures additional user-specified arguments."""
    self.sentencepiece_model_path = sentencepiece_model_path
    self.sequence_length = sequence_length
    self.stride = stride
    # When not given explicitly, derive the global length from the sequence
    # length.
    if global_sequence_length is None and sequence_length is not None:
      global_sequence_length = sequence_length // 16 + 64
    self.global_sequence_length = global_sequence_length
    logging.info(
        """
global_sequence_length: %s
sequence_length: %s
stride: %s
sentencepiece_model_path: %s""",
        self.global_sequence_length, self.sequence_length,
        self.stride, self.sentencepiece_model_path)

  def validate(self):
    """Validates that user specifies valid arguments."""
    if self.sequence_length is None:
      raise ValueError("sequence_length must be specified for BigBird.")
    if self.stride is None:
      raise ValueError("stride must be specified for BigBird.")
    if self.sentencepiece_model_path is None:
      raise ValueError(
          "sentencepiece_model_path must be specified for BigBird.")
def filter_files_for_big_bird(files):
  """Keeps only the Wikipedia-domain question file from `files`.

  The BigBird preprocessed config uses the Wikipedia domain only, so of all
  glob matches only the file whose basename starts with "wiki" is kept.

  Args:
    files: list of file path strings.

  Returns:
    A single-element list containing the Wikipedia file path.

  Raises:
    ValueError: if there is not exactly one Wikipedia file. (Previously an
      `assert`, which `python -O` would silently skip.)
  """
  filtered_files = [f for f in files if os.path.basename(f).startswith("wiki")]
  if len(filtered_files) != 1:
    raise ValueError("There should only be one wikipedia file.")
  return filtered_files
class TriviaQA(tfds.core.BeamBasedBuilder):
  """TriviaQA is a reading comprehension dataset.

  It contains over 650K question-answer-evidence triples.
  """
  # Registered TFDS name; overrides the default snake-cased class name.
  name = "bigbird_trivia_qa"
  BUILDER_CONFIGS = [
      BigBirdTriviaQAConfig(),
      TriviaQAConfig(unfiltered=False, exclude_context=False),  # rc
      TriviaQAConfig(unfiltered=False, exclude_context=True),  # rc.nocontext
      TriviaQAConfig(unfiltered=True, exclude_context=False),  # unfiltered
      TriviaQAConfig(unfiltered=True, exclude_context=True),
      # unfiltered.nocontext
  ]

  def __init__(self,
               *,
               sentencepiece_model_path=None,
               sequence_length=None,
               stride=None,
               global_sequence_length=None,
               **kwargs):
    """Initializes the builder.

    Args:
      sentencepiece_model_path: path to the SentencePiece model; only used by
        the BigBird preprocessed config.
      sequence_length: max number of tokens per example (BigBird config only).
      stride: window stride for long documents (BigBird config only).
      global_sequence_length: max number of global tokens (BigBird config
        only).
      **kwargs: forwarded to `tfds.core.BeamBasedBuilder`.
    """
    super(TriviaQA, self).__init__(**kwargs)
    # Tokenization arguments only apply to the preprocessed BigBird config.
    if isinstance(self.builder_config, BigBirdTriviaQAConfig):
      self.builder_config.configure(
          sentencepiece_model_path=sentencepiece_model_path,
          sequence_length=sequence_length,
          stride=stride,
          global_sequence_length=global_sequence_length)

  def _info(self):
    # The preprocessed BigBird variant exposes tokenized features rather than
    # the raw question/document structure.
    if isinstance(self.builder_config, BigBirdTriviaQAConfig):
      return tfds.core.DatasetInfo(
          builder=self,
          description=_DESCRIPTION,
          supervised_keys=None,
          homepage="http://nlp.cs.washington.edu/triviaqa/",
          citation=_CITATION,
          features=tfds.features.FeaturesDict({
              "id": tfds.features.Text(),
              "qid": tfds.features.Text(),
              "question": tfds.features.Text(),
              "context": tfds.features.Text(),
              # Sequence features.
              "token_ids": tfds.features.Tensor(shape=(None,), dtype=tf.int64),
              "token_offsets":
                  tfds.features.Tensor(shape=(None,), dtype=tf.int64),
              "segment_ids":
                  tfds.features.Tensor(shape=(None,), dtype=tf.int64),
              "global_token_ids":
                  tfds.features.Tensor(shape=(None,), dtype=tf.int64),
              # Start and end indices (inclusive).
              "answers":
                  tfds.features.Tensor(shape=(None, 2), dtype=tf.int64),
          }))
    # Raw configs keep the original TriviaQA JSON structure, snake-cased.
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            "question":
                tfds.features.Text(),
            "question_id":
                tfds.features.Text(),
            "question_source":
                tfds.features.Text(),
            "entity_pages":
                tfds.features.Sequence({
                    "doc_source":
                        tfds.features.Text(),
                    "filename":
                        tfds.features.Text(),
                    "title":
                        tfds.features.Text(),
                    "wiki_context":
                        tfds.features.Text(),
                }),
            "search_results":
                tfds.features.Sequence({
                    "description":
                        tfds.features.Text(),
                    "filename":
                        tfds.features.Text(),
                    "rank":
                        tf.int32,
                    "title":
                        tfds.features.Text(),
                    "url":
                        tfds.features.Text(),
                    "search_context":
                        tfds.features.Text(),
                }),
            "answer":
                tfds.features.FeaturesDict({
                    "aliases":
                        tfds.features.Sequence(tfds.features.Text()),
                    "normalized_aliases":
                        tfds.features.Sequence(tfds.features.Text()),
                    "matched_wiki_entity_name":
                        tfds.features.Text(),
                    "normalized_matched_wiki_entity_name":
                        tfds.features.Text(),
                    "normalized_value":
                        tfds.features.Text(),
                    "type":
                        tfds.features.Text(),
                    "value":
                        tfds.features.Text(),
                }),
        }),
        supervised_keys=None,
        homepage="http://nlp.cs.washington.edu/triviaqa/",
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager):
    """Returns SplitGenerators."""
    cfg = self.builder_config
    download_urls = dict()
    # The "rc" archive is needed for everything except unfiltered-nocontext,
    # since it also carries the evidence documents.
    if not (cfg.unfiltered and cfg.exclude_context):
      download_urls["rc"] = _DOWNLOAD_URL_TMPL.format("rc")
    if cfg.unfiltered:
      download_urls["unfiltered"] = _DOWNLOAD_URL_TMPL.format("unfiltered")
    file_paths = dl_manager.download_and_extract(download_urls)
    qa_dir = (
        os.path.join(file_paths["unfiltered"], "triviaqa-unfiltered")
        if cfg.unfiltered else
        os.path.join(file_paths["rc"], "qa"))
    train_files = tf.io.gfile.glob(os.path.join(qa_dir, _TRAIN_FILE_FORMAT))
    valid_files = tf.io.gfile.glob(
        os.path.join(qa_dir, _VALIDATION_FILE_FORMAT))
    test_files = tf.io.gfile.glob(os.path.join(qa_dir, _TEST_FILE_FORMAT))
    if cfg.exclude_context:
      web_evidence_dir = None
      wiki_evidence_dir = None
    else:
      web_evidence_dir = os.path.join(file_paths["rc"], _WEB_EVIDENCE_DIR)
      wiki_evidence_dir = os.path.join(file_paths["rc"], _WIKI_EVIDENCE_DIR)
    # BigBird preprocessing only uses the Wikipedia-domain question file.
    if isinstance(cfg, BigBirdTriviaQAConfig):
      train_files = filter_files_for_big_bird(train_files)
      valid_files = filter_files_for_big_bird(valid_files)
      test_files = filter_files_for_big_bird(test_files)
    # The test split ships without answers.
    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            gen_kwargs={"files": train_files,
                        "web_dir": web_evidence_dir,
                        "wiki_dir": wiki_evidence_dir,
                        "answer": True}),
        tfds.core.SplitGenerator(
            name=tfds.Split.VALIDATION,
            gen_kwargs={"files": valid_files,
                        "web_dir": web_evidence_dir,
                        "wiki_dir": wiki_evidence_dir,
                        "answer": True}),
        tfds.core.SplitGenerator(
            name=tfds.Split.TEST,
            gen_kwargs={"files": test_files,
                        "web_dir": web_evidence_dir,
                        "wiki_dir": wiki_evidence_dir,
                        "answer": False}),
    ]

  def _build_pcollection(self, pipeline, files, web_dir, wiki_dir, answer):
    """Builds the Beam pipeline producing examples for one split."""
    if isinstance(self.builder_config, BigBirdTriviaQAConfig):
      self.builder_config.validate()
      # filter_files_for_big_bird guarantees exactly one (Wikipedia) file.
      question_answers = preprocess.read_question_answers(files[0])
      return preprocess.make_pipeline(
          pipeline,
          question_answers=question_answers,
          answer=answer,
          max_num_tokens=self.builder_config.sequence_length,
          max_num_global_tokens=self.builder_config.global_sequence_length,
          stride=self.builder_config.stride,
          sentencepiece_model_path=self.builder_config.sentencepiece_model_path,
          wikipedia_dir=wiki_dir,
          web_dir=web_dir)
    # Raw configs: read questions, shuffle, then join evidence per record.
    parse_example_fn = functools.partial(parse_example,
                                         self.builder_config.exclude_context,
                                         web_dir, wiki_dir)
    return (pipeline
            | beam.Create(files)
            | beam.ParDo(ReadQuestions())
            | beam.Reshuffle()
            | beam.Map(parse_example_fn))
class ReadQuestions(beam.DoFn):
  """Reads question records out of a TriviaQA JSON file."""

  def process(self, file):
    source_name = os.path.basename(file)
    with tf.io.gfile.GFile(file) as f:
      questions = json.load(f)["Data"]
    # Tag every question with the file it came from; the source file name is
    # later used as part of the example key.
    for question in questions:
      record = {"SourceFile": source_name}
      record.update(question)
      yield record
def parse_example(exclude_context, web_dir, wiki_dir, article):
  """Return a single example from an article JSON record.

  Args:
    exclude_context: bool, if True drop search results and entity pages.
    web_dir: directory containing web (search) evidence files.
    wiki_dir: directory containing Wikipedia evidence files.
    article: dict, one question record parsed from the TriviaQA JSON.

  Returns:
    Tuple of (key, example dict) where key is "<source file>_<question id>".
  """

  def _strip(collection):
    return [item.strip() for item in collection]

  if "Answer" in article:
    answer = article["Answer"]
    answer_dict = {
        "aliases":
            _strip(answer["Aliases"]),
        "normalized_aliases":
            _strip(answer["NormalizedAliases"]),
        "matched_wiki_entity_name":
            answer.get("MatchedWikiEntryName", "").strip(),
        "normalized_matched_wiki_entity_name":
            answer.get("NormalizedMatchedWikiEntryName", "").strip(),
        "normalized_value":
            answer["NormalizedValue"].strip(),
        "type":
            answer["Type"].strip(),
        "value":
            answer["Value"].strip(),
    }
  else:
    # Test-split records carry no answer; fill with sentinel values so the
    # feature schema stays uniform across splits.
    answer_dict = {
        "aliases": [],
        "normalized_aliases": [],
        "matched_wiki_entity_name": "<unk>",
        "normalized_matched_wiki_entity_name": "<unk>",
        "normalized_value": "<unk>",
        "type": "",
        "value": "<unk>",
    }
  if exclude_context:
    article["SearchResults"] = []
    article["EntityPages"] = []

  def _add_context(collection, context_field, file_dir):
    """Adds context from file, or skips if file does not exist."""
    new_items = []
    for item in collection:
      if "Filename" not in item:
        logging.info("Missing context 'Filename', skipping.")
        continue
      new_item = item.copy()
      fname = item["Filename"]
      try:
        with tf.io.gfile.GFile(os.path.join(file_dir, fname)) as f:
          new_item[context_field] = f.read()
      except (IOError, tf.errors.NotFoundError):
        # Best-effort: evidence files can be missing from the archive.
        logging.info("File does not exist, skipping: %s", fname)
        continue
      new_items.append(new_item)
    return new_items

  def _strip_if_str(v):
    return v.strip() if isinstance(v, six.string_types) else v

  def _transpose_and_strip_dicts(dicts, field_names):
    # Converts a list of camelCase dicts into a single snake_cased
    # dict-of-lists, as expected by the tfds Sequence features.
    return {
        tfds.core.naming.camelcase_to_snakecase(k):
        [_strip_if_str(d[k]) for d in dicts] for k in field_names
    }

  search_results = _transpose_and_strip_dicts(
      _add_context(article.get("SearchResults", []), "SearchContext", web_dir),
      ["Description", "Filename", "Rank", "Title", "Url", "SearchContext"])
  entity_pages = _transpose_and_strip_dicts(
      _add_context(article.get("EntityPages", []), "WikiContext", wiki_dir),
      ["DocSource", "Filename", "Title", "WikiContext"])
  question = article["Question"].strip()
  question_id = article["QuestionId"]
  question_source = article["QuestionSource"].strip()
  return f"{article['SourceFile']}_{question_id}", {
      "entity_pages": entity_pages,
      "search_results": search_results,
      "question": question,
      "question_id": question_id,
      "question_source": question_source,
      "answer": answer_dict,
  }
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Downloads and prepares TriviaQA dataset."""
from unittest import mock
from absl import app
from absl import flags
from absl import logging
import apache_beam as beam
import tensorflow_datasets as tfds
from official.nlp.projects.triviaqa import dataset # pylint: disable=unused-import
# Command-line flags controlling BigBird preprocessing and the Beam runner.
flags.DEFINE_integer('sequence_length', 4096, 'Max number of tokens.')
flags.DEFINE_integer(
    'global_sequence_length', None,
    'Max number of question tokens plus sentences. If not set, defaults to '
    'sequence_length // 16 + 64.')
flags.DEFINE_integer(
    'stride', 3072,
    'For documents longer than `sequence_length`, where to split them.')
flags.DEFINE_string(
    'sentencepiece_model_path', None,
    'SentencePiece model to use for tokenization.')
flags.DEFINE_string('data_dir', None, 'Data directory for TFDS.')
flags.DEFINE_string('runner', 'DirectRunner', 'Beam runner to use.')

FLAGS = flags.FLAGS
def main(argv):
  """Downloads and prepares the preprocessed TriviaQA dataset with Beam."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  builder = tfds.builder(
      'bigbird_trivia_qa/rc_wiki.preprocessed',
      data_dir=FLAGS.data_dir,
      sentencepiece_model_path=FLAGS.sentencepiece_model_path,
      sequence_length=FLAGS.sequence_length,
      global_sequence_length=FLAGS.global_sequence_length,
      stride=FLAGS.stride)
  download_config = tfds.download.DownloadConfig(
      beam_options=beam.options.pipeline_options.PipelineOptions(flags=[
          f'--runner={FLAGS.runner}',
          '--direct_num_workers=8',
          '--direct_running_mode=multi_processing',
      ]))
  # NOTE(review): patches the TFDS extractor's path normalization to the
  # identity — presumably to keep evidence file paths inside the archive
  # intact during extraction; confirm against the tensorflow_datasets
  # version in use.
  with mock.patch('tensorflow_datasets.core.download.extractor._normpath',
                  new=lambda x: x):
    builder.download_and_prepare(download_config=download_config)
  logging.info(builder.info.splits)


if __name__ == '__main__':
  flags.mark_flag_as_required('sentencepiece_model_path')
  app.run(main)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluates TriviaQA predictions."""
import json
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.nlp.projects.triviaqa import evaluation
# Command-line flags: gold answers and model predictions to score.
flags.DEFINE_string('gold_path', None,
                    'Path to golden validation, i.e. wikipedia-dev.json.')
flags.DEFINE_string('predictions_path', None,
                    'Path to predictions in JSON format')

FLAGS = flags.FLAGS
def main(argv):
  """Scores a JSON predictions file against the gold TriviaQA answers."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  # Build question id -> answer object from the gold file.
  with tf.io.gfile.GFile(FLAGS.gold_path) as f:
    ground_truth = {
        datum['QuestionId']: datum['Answer'] for datum in json.load(f)['Data']
    }
  with tf.io.gfile.GFile(FLAGS.predictions_path) as f:
    predictions = json.load(f)
  logging.info(evaluation.evaluate_triviaqa(ground_truth, predictions))


if __name__ == '__main__':
  flags.mark_flag_as_required('predictions_path')
  app.run(main)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 Google LLC
# Copyright 2017 Mandar Joshi (mandar90@cs.washington.edu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Official evaluation script for v1.0 of the TriviaQA dataset.
Forked from
https://github.com/mandarjoshi90/triviaqa/blob/master/evaluation/triviaqa_evaluation.py.
Modifications are removal of main function.
"""
import collections
import re
import string
import sys
def normalize_answer(s):
  """Lower text and remove punctuation, articles and extra whitespace."""
  # Characters treated as punctuation: ASCII punctuation plus a few common
  # unicode quote/accent marks.
  punctuation = string.punctuation + '‘’´`'
  text = s.replace('_', ' ').lower()
  text = ''.join(' ' if ch in punctuation else ch for ch in text)
  text = re.sub(r'\b(a|an|the)\b', ' ', text)
  return ' '.join(text.split()).strip()
def f1_score(prediction, ground_truth):
  """Token-level F1 between normalized prediction and ground truth."""
  pred_tokens = normalize_answer(prediction).split()
  truth_tokens = normalize_answer(ground_truth).split()
  overlap = (
      collections.Counter(pred_tokens) & collections.Counter(truth_tokens))
  num_same = sum(overlap.values())
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(pred_tokens)
  recall = 1.0 * num_same / len(truth_tokens)
  return (2 * precision * recall) / (precision + recall)
def exact_match_score(prediction, ground_truth):
  """True iff prediction equals ground truth after normalization."""
  normalized_prediction = normalize_answer(prediction)
  normalized_ground_truth = normalize_answer(ground_truth)
  return normalized_prediction == normalized_ground_truth
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
  """Returns the best value of `metric_fn` over all ground truths."""
  return max(
      metric_fn(prediction, ground_truth) for ground_truth in ground_truths)
def is_exact_match(answer_object, prediction):
  """True iff `prediction` exactly matches any ground truth of the answer."""
  return any(
      exact_match_score(prediction, ground_truth)
      for ground_truth in get_ground_truths(answer_object))
def has_exact_match(ground_truths, candidates):
  """True iff any ground truth is contained in `candidates`."""
  return any(ground_truth in candidates for ground_truth in ground_truths)
def get_ground_truths(answer):
  """Normalized aliases plus normalized human answers for an answer dict."""
  human_answers = [normalize_answer(a) for a in answer.get('HumanAnswers', [])]
  return answer['NormalizedAliases'] + human_answers
def get_oracle_score(ground_truth,
                     predicted_answers,
                     qid_list=None,
                     mute=False):
  """Oracle exact-match (percentage) over the questions in `qid_list`.

  Args:
    ground_truth: dict of question id -> answer object.
    predicted_answers: dict of question id -> predicted string.
    qid_list: question ids to score; defaults to all keys of `ground_truth`.
      NOTE(review): an empty qid_list divides by zero below.
    mute: if True, suppress per-question warnings on stderr.

  Returns:
    Dict with 'oracle_exact_match' (percentage), 'common' (questions scored),
    'denominator' (questions requested), and sizes of both input dicts.
  """
  exact_match = common = 0
  if qid_list is None:
    qid_list = ground_truth.keys()
  for qid in qid_list:
    if qid not in predicted_answers:
      if not mute:
        message = 'Irrelavant question {} will receive score 0.'.format(qid)
        print(message, file=sys.stderr)
      continue
    common += 1
    prediction = normalize_answer(predicted_answers[qid])
    ground_truths = get_ground_truths(ground_truth[qid])
    # NOTE(review): `prediction` is a string here, so `has_exact_match` tests
    # substring containment (`gt in prediction`) rather than membership in a
    # candidate list — presumably intended as an oracle upper bound; confirm
    # against the upstream TriviaQA evaluation script.
    em_for_this_question = has_exact_match(ground_truths, prediction)
    exact_match += int(em_for_this_question)
  exact_match = 100.0 * exact_match / len(qid_list)
  return {
      'oracle_exact_match': exact_match,
      'common': common,
      'denominator': len(qid_list),
      'pred_len': len(predicted_answers),
      'gold_len': len(ground_truth)
  }
def evaluate_triviaqa(ground_truth,
                      predicted_answers,
                      qid_list=None,
                      mute=False):
  """Computes exact-match and F1 percentages for TriviaQA predictions.

  Args:
    ground_truth: dict of question id -> answer object.
    predicted_answers: dict of question id -> predicted answer string.
    qid_list: question ids to score; defaults to all keys of `ground_truth`.
      NOTE(review): an empty qid_list divides by zero below.
    mute: if True, suppress per-question diagnostics on stderr/stdout.

  Returns:
    Dict with 'exact_match' and 'f1' percentages over `qid_list`, plus
    'common' (questions actually scored), 'denominator' (questions
    requested), and sizes of both input dicts.
  """
  f1 = exact_match = common = 0
  if qid_list is None:
    qid_list = ground_truth.keys()
  for qid in qid_list:
    # Missing predictions and unknown question ids both score 0 but are
    # reported with different messages.
    if qid not in predicted_answers:
      if not mute:
        message = 'Missed question {} will receive score 0.'.format(qid)
        print(message, file=sys.stderr)
      continue
    if qid not in ground_truth:
      if not mute:
        message = 'Irrelavant question {} will receive score 0.'.format(qid)
        print(message, file=sys.stderr)
      continue
    common += 1
    prediction = predicted_answers[qid]
    ground_truths = get_ground_truths(ground_truth[qid])
    em_for_this_question = metric_max_over_ground_truths(
        exact_match_score, prediction, ground_truths)
    # Debug aid: show non-matching predictions next to their ground truths.
    if em_for_this_question == 0 and not mute:
      print('em=0:', prediction, ground_truths)
    exact_match += em_for_this_question
    f1_for_this_question = metric_max_over_ground_truths(
        f1_score, prediction, ground_truths)
    f1 += f1_for_this_question
  # Denominator is the full requested list, so skipped questions count as 0.
  exact_match = 100.0 * exact_match / len(qid_list)
  f1 = 100.0 * f1 / len(qid_list)
  return {
      'exact_match': exact_match,
      'f1': f1,
      'common': common,
      'denominator': len(qid_list),
      'pred_len': len(predicted_answers),
      'gold_len': len(ground_truth)
  }
This diff is collapsed.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modeling for TriviaQA."""
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.configs import encoders
class TriviaQaHead(tf.keras.layers.Layer):
  """Computes span logits given token and global embeddings.

  A feed-forward block (dense -> activation -> dense -> dropout -> residual
  layer norm) followed by a 2-unit projection producing begin/end logits per
  token.
  """

  def __init__(self,
               intermediate_size,
               intermediate_activation=tf_utils.get_activation('gelu'),
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               **kwargs):
    """Initializes the head.

    Args:
      intermediate_size: int, width of the intermediate dense layer.
      intermediate_activation: activation for the intermediate layer.
      dropout_rate: dropout applied before the residual connection.
      attention_dropout_rate: dropout applied to the input embeddings.
      **kwargs: forwarded to `tf.keras.layers.Layer`.
    """
    super(TriviaQaHead, self).__init__(**kwargs)
    self._attention_dropout = tf.keras.layers.Dropout(attention_dropout_rate)
    self._intermediate_dense = tf.keras.layers.Dense(intermediate_size)
    self._intermediate_activation = tf.keras.layers.Activation(
        intermediate_activation)
    self._output_dropout = tf.keras.layers.Dropout(dropout_rate)
    self._output_layer_norm = tf.keras.layers.LayerNormalization()
    # Two logits per token: span begin and span end.
    self._logits_dense = tf.keras.layers.Dense(2)

  def build(self, input_shape):
    # Project back to the embedding width so the residual add is valid.
    output_shape = input_shape['token_embeddings'][-1]
    self._output_dense = tf.keras.layers.Dense(output_shape)
    super(TriviaQaHead, self).build(input_shape)

  def call(self, inputs, training=None):
    token_embeddings = inputs['token_embeddings']
    token_ids = inputs['token_ids']
    question_lengths = inputs['question_lengths']
    x = self._attention_dropout(token_embeddings, training=training)
    intermediate_outputs = self._intermediate_dense(x)
    intermediate_outputs = self._intermediate_activation(intermediate_outputs)
    outputs = self._output_dense(intermediate_outputs)
    outputs = self._output_dropout(outputs, training=training)
    outputs = self._output_layer_norm(outputs + token_embeddings)
    logits = self._logits_dense(outputs)
    # Mask padding (token id 0) and question-prefix positions by subtracting
    # a large constant, so argmax/softmax never selects them as span ends.
    logits -= tf.expand_dims(
        tf.cast(tf.equal(token_ids, 0), tf.float32) + tf.sequence_mask(
            question_lengths, logits.shape[-2], dtype=tf.float32), -1) * 1e6
    return logits
class TriviaQaModel(tf.keras.Model):
  """Model for TriviaQA."""

  def __init__(self, model_config: encoders.EncoderConfig, sequence_length: int,
               **kwargs):
    """Builds encoder plus span head as a functional Keras model.

    Args:
      model_config: encoder configuration used to build the backbone.
      sequence_length: fixed token sequence length of the inputs.
      **kwargs: forwarded to `tf.keras.Model`.
    """
    inputs = dict(
        token_ids=tf.keras.Input((sequence_length,), dtype=tf.int32),
        question_lengths=tf.keras.Input((), dtype=tf.int32))
    encoder = encoders.build_encoder(model_config)
    # Type ids are 0 for the question prefix and 1 elsewhere; the input mask
    # covers every non-padding (non-zero) token id.
    x = encoder(
        dict(
            input_word_ids=inputs['token_ids'],
            input_mask=tf.cast(inputs['token_ids'] > 0, tf.int32),
            input_type_ids=1 -
            tf.sequence_mask(inputs['question_lengths'], sequence_length,
                             tf.int32)))['sequence_output']
    logits = TriviaQaHead(
        model_config.get().intermediate_size,
        dropout_rate=model_config.get().dropout_rate,
        attention_dropout_rate=model_config.get().attention_dropout_rate)(
            dict(
                token_embeddings=x,
                token_ids=inputs['token_ids'],
                question_lengths=inputs['question_lengths']))
    super(TriviaQaModel, self).__init__(inputs, logits, **kwargs)
    self._encoder = encoder

  @property
  def encoder(self):
    # The underlying text encoder backbone.
    return self._encoder
class SpanOrCrossEntropyLoss(tf.keras.losses.Loss):
  """Cross entropy loss for multiple correct answers.

  See https://arxiv.org/abs/1710.10723.
  """

  def call(self, y_true, y_pred):
    # Keep logits only where y_true marks a correct position (>= 0.5); the
    # loss is the difference of logsumexps, i.e. -log of the total
    # probability mass assigned to the correct positions ("span-or").
    y_pred_masked = y_pred - tf.cast(y_true < 0.5, tf.float32) * 1e6
    or_cross_entropy = (
        tf.math.reduce_logsumexp(y_pred, axis=-2) -
        tf.math.reduce_logsumexp(y_pred_masked, axis=-2))
    # Sum over the final (begin/end) axis.
    return tf.math.reduce_sum(or_cross_entropy, -1)
def smooth_labels(label_smoothing, labels, question_lengths, token_ids):
  """Applies label smoothing over the valid (context) token positions.

  Padding (token id 0) and question-prefix positions are excluded from the
  smoothing mass and zeroed in the output.

  Args:
    label_smoothing: float in [0, 1], fraction of mass spread uniformly.
    labels: float tensor whose second-to-last axis is the sequence axis
      (assumed per-token begin/end labels — confirm against the caller).
    question_lengths: lengths of the question prefix per example.
    token_ids: int tensor of token ids, 0 for padding.

  Returns:
    Smoothed labels, masked to valid positions.
  """
  # mask is 1 at context positions, 0 at padding and question positions.
  mask = 1. - (
      tf.cast(tf.equal(token_ids, 0), tf.float32) +
      tf.sequence_mask(question_lengths, labels.shape[-2], dtype=tf.float32))
  # Number of valid positions, used as the uniform-smoothing denominator.
  num_classes = tf.expand_dims(tf.math.reduce_sum(mask, -1, keepdims=True), -1)
  labels = (1. - label_smoothing) * labels + (label_smoothing / num_classes)
  return labels * tf.expand_dims(mask, -1)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TriviaQA script for inference."""
import collections
import contextlib
import functools
import json
import operator
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds
import sentencepiece as spm
from official.nlp.configs import encoders # pylint: disable=unused-import
from official.nlp.projects.triviaqa import evaluation
from official.nlp.projects.triviaqa import inputs
from official.nlp.projects.triviaqa import prediction
# Command-line flags for inference: dataset location, model checkpoint,
# TPU/worker configuration, and span-decoding parameters.
flags.DEFINE_string('data_dir', None, 'TensorFlow Datasets directory.')
flags.DEFINE_enum('split', None,
                  [tfds.Split.TRAIN, tfds.Split.VALIDATION, tfds.Split.TEST],
                  'For which split to generate predictions.')
flags.DEFINE_string('predictions_path', None, 'Output for predictions.')
flags.DEFINE_string('sentencepiece_model_path', None,
                    'Path to sentence piece model.')
flags.DEFINE_integer('bigbird_block_size', 64,
                     'Size of blocks for sparse block attention.')
flags.DEFINE_string('saved_model_dir', None,
                    'Path from which to initialize model and weights.')
flags.DEFINE_integer('sequence_length', 4096, 'Maximum number of tokens.')
flags.DEFINE_integer('global_sequence_length', 320,
                     'Maximum number of global tokens.')
flags.DEFINE_integer('batch_size', 32, 'Size of batch.')
flags.DEFINE_string('master', '', 'Address of the TPU master.')
flags.DEFINE_integer('decode_top_k', 8,
                     'Maximum number of tokens to consider for begin/end.')
flags.DEFINE_integer('decode_max_size', 16,
                     'Maximum number of sentence pieces in an answer.')

FLAGS = flags.FLAGS
@contextlib.contextmanager
def worker_context():
  """Pins ops to the TPU worker job when a remote master is configured."""
  if FLAGS.master:
    with tf.device('/job:worker') as d:
      yield d
  else:
    # Local execution: no device pinning.
    yield
def read_sentencepiece_model(path):
  """Loads a serialized SentencePiece model from `path`."""
  processor = spm.SentencePieceProcessor()
  with tf.io.gfile.GFile(path, 'rb') as model_file:
    serialized_model = model_file.read()
  processor.LoadFromSerializedProto(serialized_model)
  return processor
def predict(sp_processor, features_map_fn, logits_fn, decode_logits_fn,
            split_and_pad_fn, distribute_strategy, dataset):
  """Make predictions.

  Runs the model over `dataset`, decodes the best span per window, and keeps
  the highest-scoring answer per question id.

  Args:
    sp_processor: SentencePiece processor used to inspect begin tokens.
    features_map_fn: maps raw features to model inputs.
    logits_fn: distributed forward pass returning per-replica logits.
    decode_logits_fn: decodes (begin, end, score) spans from logits.
    split_and_pad_fn: shards and pads a batch across replicas.
    distribute_strategy: tf.distribute strategy whose local results are
      gathered after `logits_fn`.
    dataset: dataset of feature dicts.

  Returns:
    Dict of question id -> normalized best answer string.
  """
  predictions = collections.defaultdict(list)
  for _, features in dataset.enumerate():
    token_ids = features['token_ids']
    x = split_and_pad_fn(features_map_fn(features))
    logits = tf.concat(
        distribute_strategy.experimental_local_results(logits_fn(x)), 0)
    # Drop the padding rows added by split_and_pad_fn.
    logits = logits[:features['token_ids'].shape[0]]
    end_limit = token_ids.row_lengths() - 1  # inclusive
    begin, end, scores = decode_logits_fn(logits, end_limit)
    answers = prediction.decode_answer(features['context'], begin, end,
                                       features['token_offsets'],
                                       end_limit).numpy()
    for j, (qid, token_id, offset, score, answer) in enumerate(
        zip(features['qid'].numpy(),
            tf.gather(features['token_ids'], begin, batch_dims=1).numpy(),
            tf.gather(features['token_offsets'], begin, batch_dims=1).numpy(),
            scores, answers)):
      if not answer:
        logging.info('%s: %s | NO_ANSWER, %f',
                     features['id'][j].numpy().decode('utf-8'),
                     features['question'][j].numpy().decode('utf-8'), score)
        continue
      # SentencePiece marks word-initial pieces with '▁'; drop the leading
      # space it denotes unless the span starts at character offset 0.
      if sp_processor.IdToPiece(int(token_id)).startswith('▁') and offset > 0:
        answer = answer[1:]
      logging.info('%s: %s | %s, %f', features['id'][j].numpy().decode('utf-8'),
                   features['question'][j].numpy().decode('utf-8'),
                   answer.decode('utf-8'), score)
      predictions[qid.decode('utf-8')].append((score, answer.decode('utf-8')))
  # Keep the highest-scoring candidate per question and normalize it.
  predictions = {
      qid: evaluation.normalize_answer(
          sorted(answers, key=operator.itemgetter(0), reverse=True)[0][1])
      for qid, answers in predictions.items()
  }
  return predictions
def main(argv):
  """Restores the saved model and writes per-question predictions as JSON."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  # Configure input processing.
  sp_processor = read_sentencepiece_model(FLAGS.sentencepiece_model_path)
  features_map_fn = tf.function(
      functools.partial(
          inputs.features_map_fn,
          local_radius=FLAGS.bigbird_block_size,
          relative_pos_max_distance=24,
          use_hard_g2l_mask=True,
          sequence_length=FLAGS.sequence_length,
          global_sequence_length=FLAGS.global_sequence_length,
          padding_id=sp_processor.PieceToId('<pad>'),
          eos_id=sp_processor.PieceToId('</s>'),
          null_id=sp_processor.PieceToId('<empty>'),
          cls_id=sp_processor.PieceToId('<ans>'),
          sep_id=sp_processor.PieceToId('<sep_0>')),
      autograph=False)
  # Connect to TPU cluster, or fall back to local mirrored strategy.
  if FLAGS.master:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(FLAGS.master)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
  else:
    strategy = tf.distribute.MirroredStrategy()
  # Initialize datasets.
  with worker_context():
    _ = tf.random.get_global_generator()
    dataset = inputs.read_batches(
        FLAGS.data_dir, FLAGS.split, FLAGS.batch_size, include_answers=False)
  # Initialize model and compile.
  with strategy.scope():
    model = tf.keras.models.load_model(FLAGS.saved_model_dir, compile=False)
  logging.info('Model initialized. Beginning prediction loop.')
  logits_fn = tf.function(
      functools.partial(prediction.distributed_logits_fn, model))
  decode_logits_fn = tf.function(
      functools.partial(prediction.decode_logits, FLAGS.decode_top_k,
                        FLAGS.decode_max_size))
  split_and_pad_fn = tf.function(
      functools.partial(prediction.split_and_pad, strategy, FLAGS.batch_size))
  # Prediction strategy.
  predict_fn = functools.partial(
      predict,
      sp_processor=sp_processor,
      features_map_fn=features_map_fn,
      logits_fn=logits_fn,
      decode_logits_fn=decode_logits_fn,
      split_and_pad_fn=split_and_pad_fn,
      distribute_strategy=strategy,
      dataset=dataset)
  with worker_context():
    predictions = predict_fn()
  with tf.io.gfile.GFile(FLAGS.predictions_path, 'w') as f:
    json.dump(predictions, f)


if __name__ == '__main__':
  flags.mark_flags_as_required(['split', 'predictions_path', 'saved_model_dir'])
  app.run(main)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for inference."""
import tensorflow as tf
def split_and_pad(strategy, batch_size, x):
  """Splits a batch across replicas for inference, padding short slices.

  Args:
    strategy: `tf.distribute` strategy used to place the per-replica values.
    batch_size: Global batch size; each replica receives
      `batch_size // strategy.num_replicas_in_sync` rows.
    x: A (possibly nested) structure of tensors with a leading batch axis.

  Returns:
    A structure mirroring `x` of per-replica distributed values.
  """
  per_replica_size = batch_size // strategy.num_replicas_in_sync
  def slice_fn(x, i):
    # Clamp this replica's slice to the rows that actually exist; when the
    # final slice is short, tf.pad extends `indices` with zeros, so the
    # missing rows are filled by repeating row 0 of `x`.
    begin = min(x.shape[0], i * per_replica_size)
    end = min(x.shape[0], (i + 1) * per_replica_size)
    indices = tf.range(begin, end, dtype=tf.int32)
    return tf.gather(x, tf.pad(indices, [[0, per_replica_size - end + begin]]))
  # pylint: disable=g-long-lambda
  return tf.nest.map_structure(
      lambda x: strategy.experimental_distribute_values_from_function(
          lambda ctx: slice_fn(x, ctx.replica_id_in_sync_group)), x)
  # pylint: enable=g-long-lambda
def decode_logits(top_k, max_size, logits, default):
  """Get the best answer span from begin/end logits.

  Args:
    top_k: Number of highest-scoring begin/end positions to consider.
    max_size: Maximum allowed span width in tokens.
    logits: Float tensor; the transpose below treats it as
      [batch, seq_len, 2], with the last axis holding (begin, end) logits —
      assumed layout, TODO confirm against the model head.
    default: Value substituted for begin/end when no valid span exists.

  Returns:
    A (begin, end, score) tuple of per-batch tensors.
  """
  # -> [batch, 2, seq_len]: row 0 = begin logits, row 1 = end logits.
  logits = tf.transpose(logits, [0, 2, 1])
  values, indices = tf.math.top_k(logits, top_k)
  # width[b, i, j] = end_position_j - begin_position_i for every pair of
  # top-k begin/end candidates.
  width = (
      tf.expand_dims(indices[:, 1, :], -2) -
      tf.expand_dims(indices[:, 0, :], -1))
  # A pair is a valid span iff end >= begin and the span is short enough.
  mask = tf.logical_and(width >= 0, width <= max_size)
  # Pair score = begin logit + end logit; invalid pairs get a large negative
  # value so argmax never selects them when any valid pair exists.
  scores = (
      tf.expand_dims(values[:, 0, :], -1) + tf.expand_dims(values[:, 1, :], -2))
  scores = tf.where(mask, scores, -1e8)
  # Select the best pair over the flattened top_k x top_k grid:
  # flat = i * top_k + j, so floordiv recovers the begin candidate and mod
  # the end candidate.
  flat_indices = tf.argmax(tf.reshape(scores, (-1, top_k * top_k)), -1)
  begin = tf.gather(
      indices[:, 0, :], tf.math.floordiv(flat_indices, top_k), batch_dims=1)
  end = tf.gather(
      indices[:, 1, :], tf.math.mod(flat_indices, top_k), batch_dims=1)
  # Fall back to `default` for batch rows with no valid pair at all.
  reduced_mask = tf.math.reduce_any(mask, [-1, -2])
  return (tf.where(reduced_mask, begin,
                   default), tf.where(reduced_mask, end, default),
          tf.math.reduce_max(scores, [-1, -2]))
@tf.function
def decode_answer(context, begin, end, token_offsets, end_limit):
  """Slices answer strings out of `context` from token span indices.

  Args:
    context: [batch] string tensor holding the document text.
    begin: [batch] index of the first answer token (inclusive).
    end: [batch] index of the last answer token (inclusive).
    token_offsets: [batch, num_tokens] byte offset of each token within
      `context`; assumed int64 since it is mixed with an int64 cast below —
      TODO confirm.
    end_limit: Index of the last valid token per row; when `end` reaches it
      the answer is taken to run to the end of `context`.

  Returns:
    [batch] string tensor of answer substrings.
  """
  # Byte offset where the answer begins.
  i = tf.gather(token_offsets, begin, batch_dims=1)
  # Byte offset just past the answer: offset of the token following `end`.
  j = tf.gather(token_offsets, tf.minimum(end + 1, end_limit), batch_dims=1)
  # When the span reaches the last token, extend to the end of the context.
  j = tf.where(end == end_limit, tf.cast(tf.strings.length(context), tf.int64),
               j)
  return tf.strings.substr(context, i, j - i)
def distributed_logits_fn(model, x):
  """Runs an inference-mode forward pass under the model's own strategy."""
  def _forward(batch):
    return model(batch, training=False)
  return model.distribute_strategy.run(_forward, args=(x,))
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for preprocessing TriviaQA data."""
import bisect
import json
import operator
import os
import re
import string
from typing import Any, Dict, Generator, List, Optional, Set, Text, Tuple
from absl import logging
import apache_beam as beam
from apache_beam import metrics
import dataclasses
import nltk
import numpy as np
import tensorflow.io.gfile as gfile
import sentencepiece as spm
from official.nlp.projects.triviaqa import evaluation
from official.nlp.projects.triviaqa import sentencepiece_pb2
@dataclasses.dataclass
class Question(object):
  """A TriviaQA question."""
  # Unique question identifier from the JSON 'QuestionId' field.
  id: Text
  # The question text ('Question' field).
  value: Text
@dataclasses.dataclass
class EvidenceInfo(object):
  """Identifies an evidence document without carrying its text."""
  # Evidence file name ('Filename' field), e.g. 'Vietnam.txt'.
  id: Text
  # Originating JSON key: 'EntityPages' (Wikipedia) or 'SearchResults' (web).
  source: Text
  # Document title ('Title' field).
  title: Text
@dataclasses.dataclass
class Evidence(object):
  """An evidence document with its text loaded from disk."""
  info: EvidenceInfo
  # Full UTF-8 decoded document text.
  text: Text
@dataclasses.dataclass
class Answer(object):
  """A gold answer together with its acceptable aliases."""
  # Canonical answer string ('Value' field).
  value: Text
  # Raw aliases ('Aliases' field).
  aliases: List[Text]
  # Pre-normalized aliases ('NormalizedAliases' field).
  normalized_aliases: List[Text]
@dataclasses.dataclass
class QuestionAnswer(object):
  """A question with pointers to its evidence documents."""
  question: Question
  evidence_info: List[EvidenceInfo]
  # None for answerless (test) splits.
  answer: Optional[Answer] = None
@dataclasses.dataclass
class QuestionAnswerEvidence(object):
  """A question paired with one fully-loaded evidence document."""
  question: Question
  evidence: Evidence
  # None for answerless (test) splits.
  answer: Optional[Answer] = None
@dataclasses.dataclass
class Features(object):
  """Model inputs for one (question, document, stride window) triple."""
  # '<question_id>--<evidence_id>' (see MakeFeatures._make_features).
  id: Text
  # Index of this sliding window within the document.
  stride_index: int
  question_id: Text
  question: Text
  # UTF-8 encoded document text covered by this window.
  context: bytes
  # SentencePiece ids: question pieces, <sep_0>, document pieces, <empty>.
  token_ids: List[int]
  # Byte offset of each token within `context`; -1 for question tokens.
  token_offsets: List[int]
  # Global ids: <ans> then <unused_34> per question token, then one </s>
  # per document sentence.
  global_token_ids: List[int]
  # Per-token index of the corresponding global token; the trailing
  # <empty> token maps to the CLS slot (0).
  segment_ids: List[int]
@dataclasses.dataclass
class Paragraph(object):
  """A sentence-tokenized paragraph."""
  # One SentencePieceText proto per sentence.
  sentences: List[sentencepiece_pb2.SentencePieceText]
  # Total number of sentence pieces across all sentences.
  size: int
@dataclasses.dataclass
class AnswerSpan(object):
  """A candidate answer span.

  NOTE(review): `find_answer_spans` stores byte offsets with an exclusive
  `end` (`match.end()`), while `realign_answer_span` returns inclusive token
  indices — confirm which convention applies before relying on `end`.
  """
  begin: int  # inclusive
  end: int  # inclusive
  text: Text
def make_paragraph(
    sentence_tokenizer: nltk.tokenize.api.TokenizerI,
    processor: spm.SentencePieceProcessor,
    text: Text,
    paragraph_metric: Optional[metrics.Metrics.DelegatingDistribution] = None,
    sentence_metric: Optional[metrics.Metrics.DelegatingDistribution] = None
) -> Paragraph:
  """Tokenizes a paragraph into per-sentence SentencePiece protos.

  Args:
    sentence_tokenizer: Splits `text` into sentences.
    processor: SentencePiece processor used to encode each sentence.
    text: Raw paragraph text.
    paragraph_metric: Optional distribution updated with the paragraph's
      total piece count.
    sentence_metric: Optional distribution updated per sentence.

  Returns:
    A Paragraph holding the tokenized sentences and total piece count.
  """
  tokenized = []
  total_pieces = 0
  for raw_sentence in sentence_tokenizer.tokenize(text):
    proto = sentencepiece_pb2.SentencePieceText.FromString(
        processor.EncodeAsSerializedProto(raw_sentence))
    num_pieces = len(proto.pieces)
    total_pieces += num_pieces
    tokenized.append(proto)
    if sentence_metric:
      sentence_metric.update(num_pieces)
  if paragraph_metric:
    paragraph_metric.update(total_pieces)
  return Paragraph(sentences=tokenized, size=total_pieces)
def read_question_answers(json_path: Text) -> List[QuestionAnswer]:
  """Parses TriviaQA question/answer records from a JSON file."""
  with gfile.GFile(json_path) as f:
    data = json.load(f)['Data']
  def _parse_answer(datum):
    # The 'Answer' key is absent in answerless (test) splits.
    if 'Answer' not in datum:
      return None
    raw = datum['Answer']
    return Answer(
        value=raw['Value'],
        aliases=raw['Aliases'],
        normalized_aliases=raw['NormalizedAliases'])
  def _parse_evidence(datum):
    # `source` records which JSON section the document came from.
    return [
        EvidenceInfo(
            id=document['Filename'], title=document['Title'], source=key)
        for key in ['EntityPages', 'SearchResults']
        for document in datum.get(key, [])
    ]
  return [
      QuestionAnswer(
          question=Question(id=datum['QuestionId'], value=datum['Question']),
          evidence_info=_parse_evidence(datum),
          answer=_parse_answer(datum)) for datum in data
  ]
def alias_answer(answer: Text, include=None):
  """Normalizes an answer string into an alias.

  Lowercases, maps underscores to spaces, replaces punctuation (plus a few
  typographic quote characters) with spaces, and collapses whitespace.

  Args:
    answer: Raw answer string.
    include: Optional collection of punctuation characters to preserve.

  Returns:
    The normalized alias.
  """
  kept = include or []
  dropped = set(string.punctuation + '‘’´`')
  lowered = answer.replace('_', ' ').lower()
  translated = ''.join(
      ch if ch not in dropped or ch in kept else ' ' for ch in lowered)
  return ' '.join(translated.split()).strip()
def make_answer_set(answer: Answer) -> Set[Text]:
  """Apply less aggressive normalization to the answer aliases.

  Each alias is normalized several times, each run preserving a different
  set of punctuation, so matching tolerates punctuation variation.
  """
  variants = []
  keep_lists = [None, [',', '.'], ['-'], [',', '.', '-'], string.punctuation]
  for alias in [answer.value] + answer.aliases:
    for keep in keep_lists:
      variants.append(alias_answer(alias, keep))
  return set(variants + answer.normalized_aliases)
def find_answer_spans(text: bytes, answer_set: Set[Text]) -> List[AnswerSpan]:
  """Finds every occurrence of any acceptable answer in `text`.

  Matching is case-insensitive, and an escaped space in the answer may
  match either a space or a hyphen in the document.

  Returns:
    Byte-offset spans sorted by their begin position.
  """
  matches = []
  for candidate in answer_set:
    pattern = re.compile(
        re.escape(candidate).encode('utf-8').replace(b'\\ ', b'[ -]'),
        flags=re.IGNORECASE)
    matches.extend(
        AnswerSpan(
            begin=m.start(), end=m.end(), text=m.group(0).decode('utf-8'))
        for m in pattern.finditer(text))
  return sorted(matches, key=operator.attrgetter('begin'))
def realign_answer_span(features: Features, answer_set: Optional[Set[Text]],
                        processor: spm.SentencePieceProcessor,
                        span: AnswerSpan) -> Optional[AnswerSpan]:
  """Align answer span to text with given tokens.

  Converts a byte-level span over `features.context` into a token-level
  span over `features.token_offsets`, validating the covered text when an
  answer set is supplied.

  Args:
    features: Tokenized features whose `token_offsets` index `context`.
    answer_set: Acceptable normalized answers; when given, the realigned
      text must normalize into this set, otherwise None is returned.
    processor: SentencePiece processor used to inspect piece boundaries.
    span: Byte-level span; `span.begin` is a byte offset into `context`.

  Returns:
    A token-level AnswerSpan (begin/end inclusive), or None when the
    realigned text is not an acceptable answer.
  """
  # Token whose offset covers span.begin; bisect may land one past it.
  i = bisect.bisect_left(features.token_offsets, span.begin)
  if i == len(features.token_offsets) or span.begin < features.token_offsets[i]:
    i -= 1
  # Advance j to the last token starting before the answer's final byte.
  j = i + 1
  answer_end = span.begin + len(span.text.encode('utf-8'))
  while (j < len(features.token_offsets) and
         features.token_offsets[j] < answer_end):
    j += 1
  j -= 1
  # Reconstruct the bytes covered by tokens [i, j] for validation.
  sp_answer = (
      features.context[features.token_offsets[i]:features.token_offsets[j + 1]]
      if j + 1 < len(features.token_offsets) else
      features.context[features.token_offsets[i]:])
  # Word-initial pieces ('▁') carry the preceding space; drop it unless the
  # token is at the very start of the context.
  if (processor.IdToPiece(features.token_ids[i]).startswith('▁') and
      features.token_offsets[i] > 0):
    sp_answer = sp_answer[1:]
  sp_answer = evaluation.normalize_answer(sp_answer.decode('utf-8'))
  if answer_set is not None and sp_answer not in answer_set:
    # No need to warn if the cause was breaking word boundaries.
    if len(sp_answer) and not len(sp_answer) > len(
        evaluation.normalize_answer(span.text)):
      logging.warning('%s: "%s" not in %s.', features.question_id, sp_answer,
                      answer_set)
    return None
  return AnswerSpan(begin=i, end=j, text=span.text)
def read_sentencepiece_model(path):
  """Loads a serialized SentencePiece model from `path`."""
  processor = spm.SentencePieceProcessor()
  with gfile.GFile(path, 'rb') as model_file:
    serialized_model = model_file.read()
  processor.LoadFromSerializedProto(serialized_model)
  return processor
class ReadEvidence(beam.DoFn):
  """Reads the evidence documents referenced by each question.

  Wikipedia evidence ('EntityPages') is resolved against `wikipedia_dir`
  and web search evidence ('SearchResults') against `web_dir`.
  """
  def __init__(self, wikipedia_dir: Text, web_dir: Text):
    self._wikipedia_dir = wikipedia_dir
    self._web_dir = web_dir
  def process(
      self, question_answer: QuestionAnswer
  ) -> Generator[QuestionAnswerEvidence, None, None]:
    """Yields one QuestionAnswerEvidence per evidence document.

    Raises:
      ValueError: If an evidence source is neither 'EntityPages' nor
        'SearchResults'.
    """
    for info in question_answer.evidence_info:
      if info.source == 'EntityPages':
        evidence_path = os.path.join(self._wikipedia_dir, info.id)
      # Bug fix: `EvidenceInfo.source` is populated from the JSON key
      # 'SearchResults' (plural) in `read_question_answers`; the previous
      # comparison against 'SearchResult' made every web document fall
      # through to the ValueError below.
      elif info.source == 'SearchResults':
        evidence_path = os.path.join(self._web_dir, info.id)
      else:
        raise ValueError(f'Unknown evidence source: {info.source}.')
      with gfile.GFile(evidence_path, 'rb') as f:
        text = f.read().decode('utf-8')
      metrics.Metrics.counter('_', 'documents').inc()
      yield QuestionAnswerEvidence(
          question=question_answer.question,
          evidence=Evidence(info=info, text=text),
          answer=question_answer.answer)
# Special SentencePiece tokens used when assembling model inputs
# (see MakeFeatures._make_features).
_CLS_PIECE = '<ans>'  # First global token.
_EOS_PIECE = '</s>'  # One global token appended per document sentence.
_SEP_PIECE = '<sep_0>'  # Separates question tokens from document tokens.
# _PARAGRAPH_SEP_PIECE = '<sep_1>'
_NULL_PIECE = '<empty>'  # Trailing "no answer" token appended to token_ids.
_QUESTION_PIECE = '<unused_34>'  # Global token repeated per question token.
class MakeFeatures(beam.DoFn):
  """Converts (question, evidence) pairs into sliding-window features.

  Each document is split into overlapping windows of at most
  `max_num_tokens` tokens; successive windows advance by whole paragraphs
  totalling roughly `stride` sentence pieces.
  """
  def __init__(self, sentencepiece_model_path: Text, max_num_tokens: int,
               max_num_global_tokens: int, stride: int):
    self._sentencepiece_model_path = sentencepiece_model_path
    self._max_num_tokens = max_num_tokens
    self._max_num_global_tokens = max_num_global_tokens
    self._stride = stride
  def setup(self):
    # Heavy resources are built once per worker rather than pickled in.
    self._sentence_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
    self._sentencepiece_processor = read_sentencepiece_model(
        self._sentencepiece_model_path)
  def _make_features(self, stride_index: int, paragraph_texts: List[Text],
                     paragraphs: List[Paragraph],
                     question_answer_evidence: QuestionAnswerEvidence,
                     ids: List[int],
                     paragraph_offset: int) -> Tuple[int, Features]:
    """Builds the features for one window starting at `paragraph_offset`.

    `paragraphs` is a cache of tokenized paragraphs that is extended in
    place as new paragraphs are consumed; `ids` holds the pre-tokenized
    question (including the trailing separator token).

    Returns:
      (next_paragraph_index, features): the paragraph index at which the
      next window should start, and this window's Features.
    """
    # Global tokens: CLS slot followed by one slot per question token.
    global_ids = (
        [self._sentencepiece_processor.PieceToId(_CLS_PIECE)] +
        [self._sentencepiece_processor.PieceToId(_QUESTION_PIECE)] * len(ids))
    segment_ids = [i + 1 for i in range(len(ids))]  # offset for CLS token
    token_ids, sentences = [], []
    # Question tokens have no position in the context, hence offset -1.
    offsets, offset, full_text = [-1] * len(ids), 0, True
    for i in range(paragraph_offset, len(paragraph_texts)):
      if i < len(paragraphs):
        paragraph = paragraphs[i]
      else:
        # Tokenize lazily and cache for later windows of this document.
        paragraphs.append(
            make_paragraph(
                self._sentence_tokenizer,
                self._sentencepiece_processor,
                paragraph_texts[i],
                paragraph_metric=metrics.Metrics.distribution(
                    '_', 'paragraphs'),
                sentence_metric=metrics.Metrics.distribution('_', 'sentences')))
        paragraph = paragraphs[-1]
      for sentence in paragraph.sentences:
        # Stop when this sentence would overflow the token budget (+1
        # reserves room for the trailing <empty> token) or the global
        # token budget.
        if (len(ids) + len(token_ids) + len(sentence.pieces) + 1 >=
            self._max_num_tokens or
            len(global_ids) >= self._max_num_global_tokens):
          full_text = False
          break
        for j, piece in enumerate(sentence.pieces):
          token_ids.append(piece.id)
          segment_ids.append(len(global_ids))
          offsets.append(offset + piece.begin)
          # The first piece of a follow-on sentence absorbs the space that
          # joins sentences in `context`.
          if j == 0 and sentences:
            offsets[-1] -= 1
        # +1 accounts for the joining space between sentences.
        offset += len(sentence.text.encode('utf-8')) + 1
        global_ids.append(self._sentencepiece_processor.PieceToId(_EOS_PIECE))
        sentences.append(sentence.text)
      if not full_text:
        break
    context = ' '.join(sentences).encode('utf-8')
    # Trailing "no answer" token points at the CLS global slot (segment 0).
    token_ids.append(self._sentencepiece_processor.PieceToId(_NULL_PIECE))
    offsets.append(len(context))
    segment_ids.append(0)
    next_paragraph_index = len(paragraph_texts)
    if not full_text and self._stride > 0:
      # Advance by whole paragraphs whose combined piece count stays
      # within the stride budget (always at least one paragraph).
      shift = paragraphs[paragraph_offset].size
      next_paragraph_index = paragraph_offset + 1
      while (next_paragraph_index < len(paragraphs) and
             shift + paragraphs[next_paragraph_index].size <= self._stride):
        shift += paragraphs[next_paragraph_index].size
        next_paragraph_index += 1
    return next_paragraph_index, Features(
        id='{}--{}'.format(question_answer_evidence.question.id,
                           question_answer_evidence.evidence.info.id),
        stride_index=stride_index,
        question_id=question_answer_evidence.question.id,
        question=question_answer_evidence.question.value,
        context=context,
        token_ids=ids + token_ids,
        global_token_ids=global_ids,
        segment_ids=segment_ids,
        token_offsets=offsets)
  def process(
      self, question_answer_evidence: QuestionAnswerEvidence
  ) -> Generator[Features, None, None]:
    """Yields Features for every sliding window over the evidence text."""
    # Tokenize question which is shared among all examples.
    ids = (
        self._sentencepiece_processor.EncodeAsIds(
            question_answer_evidence.question.value) +
        [self._sentencepiece_processor.PieceToId(_SEP_PIECE)])
    # Non-empty, whitespace-stripped paragraphs of the evidence document.
    paragraph_texts = list(
        filter(
            lambda p: p,
            map(lambda p: p.strip(),
                question_answer_evidence.evidence.text.split('\n'))))
    stride_index, paragraphs, paragraph_index = 0, [], 0
    while paragraph_index < len(paragraph_texts):
      paragraph_index, features = self._make_features(stride_index,
                                                      paragraph_texts,
                                                      paragraphs,
                                                      question_answer_evidence,
                                                      ids, paragraph_index)
      stride_index += 1
      yield features
def _handle_exceptional_examples(
    features: Features,
    processor: spm.SentencePieceProcessor) -> List[AnswerSpan]:
  """Returns hand-curated answer spans for known problematic documents.

  These documents' answers are not found by the generic alias matching, so
  spans are located by searching for document-specific byte patterns.

  Args:
    features: Tokenized features for one window of a document.
    processor: SentencePiece processor used for span realignment.

  Returns:
    Realigned token-level spans; empty when the window does not contain
    the pattern or the document is not special-cased.
  """
  if features.id == 'qw_6687--Viola.txt':
    pattern = 'three strings in common—G, D, and A'.encode('utf-8')
    i = features.context.find(pattern)
    if i != -1:
      # The answer is the trailing 'A'.
      span = AnswerSpan(i + len(pattern) - 1, i + len(pattern), 'A')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      return [span]
  if features.id == 'sfq_26183--Vitamin_A.txt':
    pattern = ('Vitamin A is a group of unsaturated nutritional organic '
               'compounds that includes retinol').encode('utf-8')
    i = features.context.find(pattern)
    if i != -1:
      # Both 'A' alone and 'Vitamin A' are acceptable spans.
      span = AnswerSpan(i + pattern.find(b'A'), i + pattern.find(b'A') + 1, 'A')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      spans = [span]
      span = AnswerSpan(i, i + pattern.find(b'A') + 1, 'Vitamin A')
      span = realign_answer_span(features, None, processor, span)
      return spans + [span]
  if features.id == 'odql_292--Colombia.txt':
    pattern = b'Colombia is the third-most populous country in Latin America'
    i = features.context.find(pattern)
    if i != -1:
      span = AnswerSpan(i, i + len(b'Colombia'), 'Colombia')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      return [span]
  if features.id == 'tc_1648--Vietnam.txt':
    pattern = 'Bảo Đại'.encode('utf-8')
    i = features.context.find(pattern)
    if i != -1:
      span = AnswerSpan(i, i + len(pattern), 'Bảo Đại')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      return [span]
  if features.id == 'sfq_22225--Irish_mythology.txt':
    pattern = 'Tír na nÓg'.encode('utf-8')
    spans = []
    # Bug fix: the original loop re-ran `context.find(pattern)` without a
    # start position inside the body, so it always resolved the FIRST
    # occurrence and spun forever whenever the pattern occurred more than
    # once. Scan forward from past the previous match instead.
    i = features.context.find(pattern)
    while i != -1:
      span = AnswerSpan(i, i + len(pattern), 'Tír na nÓg')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      spans.append(span)
      i = features.context.find(pattern, i + len(pattern))
    return spans
  return []
class FindAnswerSpans(beam.DoFn):
  """Finds token-level answer spans in each tokenized document.

  Consumes features grouped by question id together with a side input
  mapping question id -> set of acceptable normalized answers.
  """
  def __init__(self, sentencepiece_model_path: Text):
    # Only the path is pickled; the processor is built per worker in setup.
    self._sentencepiece_model_path = sentencepiece_model_path
  def setup(self):
    self._sentencepiece_processor = read_sentencepiece_model(
        self._sentencepiece_model_path)
  def process(
      self,
      element: Tuple[Text, List[Features]],
      answer_sets: Dict[Text, Set[Text]],
  ) -> Generator[Tuple[Features, List[AnswerSpan]], None, None]:
    """Yields (features, answer_spans) for every feature of a question."""
    question_id, features = element
    answer_set = answer_sets[question_id]
    has_answer = False
    for feature in features:
      answer_spans = []
      for answer_span in find_answer_spans(feature.context, answer_set):
        # Keep only byte spans that survive token realignment/validation.
        realigned_answer_span = realign_answer_span(
            feature, answer_set, self._sentencepiece_processor, answer_span)
        if realigned_answer_span:
          answer_spans.append(realigned_answer_span)
      if not answer_spans:
        # Fall back to hand-curated spans for known problematic documents.
        answer_spans = _handle_exceptional_examples(
            feature, self._sentencepiece_processor)
      if answer_spans:
        has_answer = True
      else:
        metrics.Metrics.counter('_', 'answerless_examples').inc()
      yield feature, answer_spans
    if not has_answer:
      metrics.Metrics.counter('_', 'answerless_questions').inc()
      logging.error('Question %s has no answer.', question_id)
def make_example(
    features: Features,
    labels: Optional[List[AnswerSpan]] = None) -> Tuple[Text, Dict[Text, Any]]:
  """Converts Features (and optional answer spans) into a keyed example.

  Returns:
    A ('<id>--<stride_index>', feature dict) pair; 'answers' holds a
    [num_answers, 2] int64 array of (begin, end) token indices, empty when
    no labels are given.
  """
  feature = {
      'id': features.id,
      'qid': features.question_id,
      'question': features.question,
      'context': features.context,
      'token_ids': features.token_ids,
      'token_offsets': features.token_offsets,
      'segment_ids': features.segment_ids,
      'global_token_ids': features.global_token_ids,
  }
  if not labels:
    feature['answers'] = np.zeros([0, 2], np.int64)
  else:
    # Deduplicate identical (begin, end) spans.
    unique_spans = set((label.begin, label.end) for label in labels)
    feature['answers'] = np.array([[b, e] for b, e in unique_spans], np.int64)
  metrics.Metrics.counter('_', 'examples').inc()
  return f'{features.id}--{features.stride_index}', feature
def make_pipeline(root: beam.Pipeline, question_answers: List[QuestionAnswer],
                  answer: bool, max_num_tokens: int, max_num_global_tokens: int,
                  stride: int, sentencepiece_model_path: Text,
                  wikipedia_dir: Text, web_dir: Text):
  """Makes a Beam pipeline.

  Args:
    root: Pipeline root to attach the transforms to.
    question_answers: Parsed records (see `read_question_answers`).
    answer: When True, group features by question id and attach answer
      spans; when False, emit unlabeled examples.
    max_num_tokens: Maximum tokens per feature window.
    max_num_global_tokens: Maximum global tokens per feature window.
    stride: Piece budget advanced between overlapping windows.
    sentencepiece_model_path: Path to the SentencePiece model file.
    wikipedia_dir: Directory holding Wikipedia evidence files.
    web_dir: Directory holding web-search evidence files.

  Returns:
    A PCollection of (key, feature dict) examples (see `make_example`).
  """
  question_answers = (
      root | 'CreateQuestionAnswers' >> beam.Create(question_answers))
  features = (
      question_answers
      | 'ReadEvidence' >> beam.ParDo(
          ReadEvidence(wikipedia_dir=wikipedia_dir, web_dir=web_dir))
      | 'MakeFeatures' >> beam.ParDo(
          MakeFeatures(
              sentencepiece_model_path=sentencepiece_model_path,
              max_num_tokens=max_num_tokens,
              max_num_global_tokens=max_num_global_tokens,
              stride=stride)))
  if answer:
    features = features | 'KeyFeature' >> beam.Map(
        lambda feature: (feature.question_id, feature))
    # pylint: disable=g-long-lambda
    # Side input: question id -> set of acceptable normalized answers.
    answer_sets = (
        question_answers
        | 'MakeAnswerSet' >>
        beam.Map(lambda qa: (qa.question.id, make_answer_set(qa.answer))))
    # pylint: enable=g-long-lambda
    examples = (
        features
        | beam.GroupByKey()
        | 'FindAnswerSpans' >> beam.ParDo(
            FindAnswerSpans(sentencepiece_model_path),
            answer_sets=beam.pvalue.AsDict(answer_sets))
        | 'MakeExamplesWithLabels' >> beam.MapTuple(make_example))
  else:
    examples = features | 'MakeExamples' >> beam.Map(make_example)
  return examples
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
# pylint: disable=bad-continuation
# pylint: disable=protected-access
# Generated by the protocol buffer compiler. DO NOT EDIT!
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='third_party/sentencepiece/src/sentencepiece.proto',
package='sentencepiece',
syntax='proto2',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n1third_party/sentencepiece/src/sentencepiece.proto\x12\rsentencepiece\"\xdf\x01\n\x11SentencePieceText\x12\x0c\n\x04text\x18\x01 \x01(\t\x12>\n\x06pieces\x18\x02 \x03(\x0b\x32..sentencepiece.SentencePieceText.SentencePiece\x12\r\n\x05score\x18\x03 \x01(\x02\x1a\x62\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\r\x12\x0f\n\x07surface\x18\x03 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x04 \x01(\r\x12\x0b\n\x03\x65nd\x18\x05 \x01(\r*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"J\n\x16NBestSentencePieceText\x12\x30\n\x06nbests\x18\x01 \x03(\x0b\x32 .sentencepiece.SentencePieceText'
)
_SENTENCEPIECETEXT_SENTENCEPIECE = _descriptor.Descriptor(
name='SentencePiece',
full_name='sentencepiece.SentencePieceText.SentencePiece',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='piece',
full_name='sentencepiece.SentencePieceText.SentencePiece.piece',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=b''.decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR,
create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='id',
full_name='sentencepiece.SentencePieceText.SentencePiece.id',
index=1,
number=2,
type=13,
cpp_type=3,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR,
create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='surface',
full_name='sentencepiece.SentencePieceText.SentencePiece.surface',
index=2,
number=3,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=b''.decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR,
create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='begin',
full_name='sentencepiece.SentencePieceText.SentencePiece.begin',
index=3,
number=4,
type=13,
cpp_type=3,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR,
create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='end',
full_name='sentencepiece.SentencePieceText.SentencePiece.end',
index=4,
number=5,
type=13,
cpp_type=3,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR,
create_key=_descriptor._internal_create_key),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=True,
syntax='proto2',
extension_ranges=[
(200, 536870912),
],
oneofs=[],
serialized_start=183,
serialized_end=281,
)
_SENTENCEPIECETEXT = _descriptor.Descriptor(
name='SentencePieceText',
full_name='sentencepiece.SentencePieceText',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='text',
full_name='sentencepiece.SentencePieceText.text',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=b''.decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR,
create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='pieces',
full_name='sentencepiece.SentencePieceText.pieces',
index=1,
number=2,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR,
create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='score',
full_name='sentencepiece.SentencePieceText.score',
index=2,
number=3,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR,
create_key=_descriptor._internal_create_key),
],
extensions=[],
nested_types=[
_SENTENCEPIECETEXT_SENTENCEPIECE,
],
enum_types=[],
serialized_options=None,
is_extendable=True,
syntax='proto2',
extension_ranges=[
(200, 536870912),
],
oneofs=[],
serialized_start=69,
serialized_end=292,
)
_NBESTSENTENCEPIECETEXT = _descriptor.Descriptor(
name='NBestSentencePieceText',
full_name='sentencepiece.NBestSentencePieceText',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='nbests',
full_name='sentencepiece.NBestSentencePieceText.nbests',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR,
create_key=_descriptor._internal_create_key),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=294,
serialized_end=368,
)
_SENTENCEPIECETEXT_SENTENCEPIECE.containing_type = _SENTENCEPIECETEXT
_SENTENCEPIECETEXT.fields_by_name[
'pieces'].message_type = _SENTENCEPIECETEXT_SENTENCEPIECE
_NBESTSENTENCEPIECETEXT.fields_by_name[
'nbests'].message_type = _SENTENCEPIECETEXT
DESCRIPTOR.message_types_by_name['SentencePieceText'] = _SENTENCEPIECETEXT
DESCRIPTOR.message_types_by_name[
'NBestSentencePieceText'] = _NBESTSENTENCEPIECETEXT
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
SentencePieceText = _reflection.GeneratedProtocolMessageType(
'SentencePieceText',
(_message.Message,),
{
'SentencePiece':
_reflection.GeneratedProtocolMessageType(
'SentencePiece',
(_message.Message,),
{
'DESCRIPTOR':
_SENTENCEPIECETEXT_SENTENCEPIECE,
'__module__':
'official.nlp.projects.triviaqa.sentencepiece_pb2'
# @@protoc_insertion_point(class_scope:sentencepiece.SentencePieceText.SentencePiece)
}),
'DESCRIPTOR':
_SENTENCEPIECETEXT,
'__module__':
'official.nlp.projects.triviaqa.sentencepiece_pb2'
# @@protoc_insertion_point(class_scope:sentencepiece.SentencePieceText)
})
_sym_db.RegisterMessage(SentencePieceText)
_sym_db.RegisterMessage(SentencePieceText.SentencePiece)
NBestSentencePieceText = _reflection.GeneratedProtocolMessageType(
'NBestSentencePieceText',
(_message.Message,),
{
'DESCRIPTOR': _NBESTSENTENCEPIECETEXT,
'__module__': 'official.nlp.projects.triviaqa.sentencepiece_pb2'
# @@protoc_insertion_point(class_scope:sentencepiece.NBestSentencePieceText)
})
_sym_db.RegisterMessage(NBestSentencePieceText)
# @@protoc_insertion_point(module_scope)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TriviaQA training script."""
import collections
import contextlib
import functools
import json
import operator
import os
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
import tensorflow_datasets as tfds
import sentencepiece as spm
from official.nlp import optimization as nlp_optimization
from official.nlp.configs import encoders
from official.nlp.projects.triviaqa import evaluation
from official.nlp.projects.triviaqa import inputs
from official.nlp.projects.triviaqa import modeling
from official.nlp.projects.triviaqa import prediction
# Data and model I/O flags.
flags.DEFINE_string('data_dir', None, 'Data directory for TensorFlow Datasets.')
flags.DEFINE_string(
    'validation_gold_path', None,
    'Path to golden validation. Usually, the wikipedia-dev.json file.')
flags.DEFINE_string('model_dir', None,
                    'Directory for checkpoints and summaries.')
# Bug fix: help string previously read 'coniguration'.
flags.DEFINE_string('model_config_path', None,
                    'JSON file containing model configuration.')
flags.DEFINE_string('sentencepiece_model_path', None,
                    'Path to sentence piece model.')
# Model architecture flags.
flags.DEFINE_enum('encoder', 'bigbird',
                  ['bert', 'bigbird', 'albert', 'mobilebert'],
                  'Which transformer encoder model to use.')
flags.DEFINE_integer('bigbird_block_size', 64,
                     'Size of blocks for sparse block attention.')
flags.DEFINE_string('init_checkpoint_path', None,
                    'Path from which to initialize weights.')
# Sequence length and batching flags.
flags.DEFINE_integer('train_sequence_length', 4096,
                     'Maximum number of tokens for training.')
flags.DEFINE_integer('train_global_sequence_length', 320,
                     'Maximum number of global tokens for training.')
flags.DEFINE_integer('validation_sequence_length', 4096,
                     'Maximum number of tokens for validation.')
flags.DEFINE_integer('validation_global_sequence_length', 320,
                     'Maximum number of global tokens for validation.')
flags.DEFINE_integer('batch_size', 32, 'Size of batch.')
flags.DEFINE_string('master', '', 'Address of the TPU master.')
# Decoding and regularization flags.
flags.DEFINE_integer('decode_top_k', 8,
                     'Maximum number of tokens to consider for begin/end.')
flags.DEFINE_integer('decode_max_size', 16,
                     'Maximum number of sentence pieces in an answer.')
flags.DEFINE_float('dropout_rate', 0.1, 'Dropout rate for hidden layers.')
flags.DEFINE_float('attention_dropout_rate', 0.3,
                   'Dropout rate for attention layers.')
flags.DEFINE_float('label_smoothing', 1e-1, 'Degree of label smoothing.')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files')
FLAGS = flags.FLAGS
@contextlib.contextmanager
def worker_context():
  """Places the body on the TPU worker host when a master address is set."""
  if not FLAGS.master:
    yield
    return
  with tf.device('/job:worker') as device:
    yield device
def read_sentencepiece_model(path):
  """Loads a serialized SentencePiece model from `path`."""
  processor = spm.SentencePieceProcessor()
  with tf.io.gfile.GFile(path, 'rb') as model_file:
    serialized_model = model_file.read()
  processor.LoadFromSerializedProto(serialized_model)
  return processor
# Rename old BERT v1 configuration parameters to their modern equivalents
# (applied in `read_model_config` before merging into the encoder config).
_MODEL_CONFIG_REPLACEMENTS = {
    'num_hidden_layers': 'num_layers',
    'attention_probs_dropout_prob': 'attention_dropout_rate',
    'hidden_dropout_prob': 'dropout_rate',
    'hidden_act': 'hidden_activation',
    'window_size': 'block_size',
}
def read_model_config(encoder,
                      path,
                      bigbird_block_size=None,
                      attention_dropout_rate=None,
                      dropout_rate=None) -> encoders.EncoderConfig:
  """Merges the JSON configuration into the encoder configuration.

  Args:
    encoder: Encoder type name understood by `encoders.EncoderConfig`.
    path: Path to a (possibly BERT-v1 style) JSON model configuration.
    bigbird_block_size: Size of blocks for sparse block attention.
    attention_dropout_rate: Attention dropout; defaults to
      FLAGS.attention_dropout_rate when None (previously this was an
      implicit, non-overridable FLAGS dependency).
    dropout_rate: Hidden-layer dropout; defaults to FLAGS.dropout_rate
      when None.

  Returns:
    An `encoders.EncoderConfig` with the JSON values merged in; unknown
    JSON keys are logged and ignored.
  """
  with tf.io.gfile.GFile(path) as f:
    model_config = json.load(f)
  # Rename legacy BERT v1 keys to their modern equivalents.
  for key, value in _MODEL_CONFIG_REPLACEMENTS.items():
    if key in model_config:
      model_config[value] = model_config.pop(key)
  if attention_dropout_rate is None:
    attention_dropout_rate = FLAGS.attention_dropout_rate
  if dropout_rate is None:
    dropout_rate = FLAGS.dropout_rate
  model_config['attention_dropout_rate'] = attention_dropout_rate
  model_config['dropout_rate'] = dropout_rate
  model_config['block_size'] = bigbird_block_size
  encoder_config = encoders.EncoderConfig(type=encoder)
  # Override the default config with those loaded from the JSON file.
  encoder_config_keys = encoder_config.get().as_dict().keys()
  overrides = {}
  for key, value in model_config.items():
    if key in encoder_config_keys:
      overrides[key] = value
    else:
      logging.warning('Ignoring config parameter %s=%s', key, value)
  encoder_config.get().override(overrides)
  return encoder_config
# NOTE(review): newer gin releases renamed `blacklist` to `denylist` —
# confirm the pinned gin version before upgrading.
@gin.configurable(blacklist=[
    'model',
    'strategy',
    'train_dataset',
    'model_dir',
    'init_checkpoint_path',
    'evaluate_fn',
])
def fit(model,
        strategy,
        train_dataset,
        model_dir,
        init_checkpoint_path=None,
        evaluate_fn=None,
        learning_rate=1e-5,
        learning_rate_polynomial_decay_rate=1.,
        weight_decay_rate=1e-1,
        num_warmup_steps=5000,
        num_decay_steps=51000,
        num_epochs=6):
  """Train and evaluate.

  Trains for up to `num_epochs` epochs, resuming from existing checkpoints,
  and checkpoints after every epoch. When `evaluate_fn` is provided, the
  model is exported as a SavedModel whenever exact match improves.

  Args:
    model: Keras model with an `encoder` attribute; compiled here.
    strategy: `tf.distribute` strategy scoping optimizer/model state.
    train_dataset: Dataset passed to `model.fit`, once per epoch.
    model_dir: Directory for checkpoints, summaries and exports.
    init_checkpoint_path: Optional encoder-only warm-start checkpoint.
    evaluate_fn: Optional callable returning a metrics dict containing
      'exact_match'; when None, evaluation and export are skipped.
    learning_rate: Peak learning rate reached after warmup.
    learning_rate_polynomial_decay_rate: Power of the polynomial decay.
    weight_decay_rate: AdamW weight decay rate.
    num_warmup_steps: Number of linear warmup steps.
    num_decay_steps: Steps over which the rate decays to zero.
    num_epochs: Total number of training epochs.
  """
  hparams = dict(
      learning_rate=learning_rate,
      num_decay_steps=num_decay_steps,
      num_warmup_steps=num_warmup_steps,
      num_epochs=num_epochs,
      weight_decay_rate=weight_decay_rate,
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      label_smoothing=FLAGS.label_smoothing)
  logging.info(hparams)
  # Linear warmup into a polynomial decay towards zero.
  learning_rate_schedule = nlp_optimization.WarmUp(
      learning_rate,
      tf.keras.optimizers.schedules.PolynomialDecay(
          learning_rate,
          num_decay_steps,
          end_learning_rate=0.,
          power=learning_rate_polynomial_decay_rate), num_warmup_steps)
  with strategy.scope():
    optimizer = nlp_optimization.AdamWeightDecay(
        learning_rate_schedule,
        weight_decay_rate=weight_decay_rate,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
    model.compile(optimizer, loss=modeling.SpanOrCrossEntropyLoss())
  def init_fn(init_checkpoint_path):
    # Warm-start only the encoder weights from the given checkpoint.
    ckpt = tf.train.Checkpoint(encoder=model.encoder)
    ckpt.restore(init_checkpoint_path).assert_existing_objects_matched()
  with worker_context():
    ckpt_manager = tf.train.CheckpointManager(
        tf.train.Checkpoint(model=model, optimizer=optimizer),
        model_dir,
        max_to_keep=None,
        init_fn=(functools.partial(init_fn, init_checkpoint_path)
                 if init_checkpoint_path else None))
    with strategy.scope():
      ckpt_manager.restore_or_initialize()
    val_summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'val'))
    best_exact_match = 0.
    # One checkpoint is saved per finished epoch, so the checkpoint count
    # doubles as the epoch index to resume from.
    for epoch in range(len(ckpt_manager.checkpoints), num_epochs):
      model.fit(
          train_dataset,
          callbacks=[
              tf.keras.callbacks.TensorBoard(model_dir, write_graph=False),
          ])
      ckpt_path = ckpt_manager.save()
      if evaluate_fn is None:
        continue
      metrics = evaluate_fn()
      logging.info('Epoch %d: %s', epoch + 1, metrics)
      if best_exact_match < metrics['exact_match']:
        best_exact_match = metrics['exact_match']
        model.save(os.path.join(model_dir, 'export'), include_optimizer=False)
        logging.info('Exporting %s as SavedModel.', ckpt_path)
      with val_summary_writer.as_default():
        for name, data in metrics.items():
          tf.summary.scalar(name, data, epoch + 1)
def evaluate(sp_processor, features_map_fn, labels_map_fn, logits_fn,
             decode_logits_fn, split_and_pad_fn, distribute_strategy,
             validation_dataset, ground_truth):
  """Runs TriviaQA validation and returns the official metrics plus loss.

  For every batch the model's span logits are computed across replicas, the
  best answer span per question is decoded, and the highest-scoring answer
  string per question id is kept. The aggregated predictions are scored
  against `ground_truth` with the official TriviaQA evaluation.

  Args:
    sp_processor: SentencePiece processor; used to detect word-initial pieces
      so a leading whitespace character can be stripped from answers.
    features_map_fn: Maps raw batch features into model inputs.
    labels_map_fn: Maps `(token_ids, labels)` into loss targets.
    logits_fn: Distributed function producing per-replica span logits.
    decode_logits_fn: Decodes `(begin, end, score)` spans from logits.
    split_and_pad_fn: Splits a batch across replicas, padding as needed.
    distribute_strategy: Strategy owning `logits_fn`'s per-replica results.
    validation_dataset: Dataset yielding `(features, labels)` batches.
    ground_truth: Mapping of question id to its gold answer record.

  Returns:
    Dict of TriviaQA metrics (including 'exact_match') plus the mean
    validation 'loss'.
  """
  loss_metric = tf.keras.metrics.Mean()

  @tf.function
  def update_loss(y, logits):
    loss_fn = modeling.SpanOrCrossEntropyLoss(
        reduction=tf.keras.losses.Reduction.NONE)
    return loss_metric(loss_fn(y, logits))

  predictions = collections.defaultdict(list)
  # Iterate the dataset directly: the previous `.enumerate()` index was
  # never used.
  for features, labels in validation_dataset:
    token_ids = features['token_ids']
    y = labels_map_fn(token_ids, labels)
    x = split_and_pad_fn(features_map_fn(features))
    # Concatenate per-replica results and drop the padding rows added by
    # split_and_pad_fn so rows line up with the original batch again.
    logits = tf.concat(
        distribute_strategy.experimental_local_results(logits_fn(x)), 0)
    logits = logits[:features['token_ids'].shape[0]]
    update_loss(y, logits)
    end_limit = token_ids.row_lengths() - 1  # inclusive
    begin, end, scores = decode_logits_fn(logits, end_limit)
    answers = prediction.decode_answer(features['context'], begin, end,
                                       features['token_offsets'],
                                       end_limit).numpy()
    # Plain zip — the index produced by the previous `enumerate` was unused.
    for qid, token_id, offset, score, answer in zip(
        features['qid'].numpy(),
        tf.gather(features['token_ids'], begin, batch_dims=1).numpy(),
        tf.gather(features['token_offsets'], begin, batch_dims=1).numpy(),
        scores, answers):
      if not answer:
        continue
      # '▁' marks a word-initial SentencePiece piece; drop the leading
      # whitespace it contributes unless the span starts the context.
      if sp_processor.IdToPiece(int(token_id)).startswith('▁') and offset > 0:
        answer = answer[1:]
      predictions[qid.decode('utf-8')].append((score, answer.decode('utf-8')))
  # Keep only the highest-scoring answer per question, normalized the same
  # way the official evaluation normalizes gold answers.
  predictions = {
      qid: evaluation.normalize_answer(
          sorted(answers, key=operator.itemgetter(0), reverse=True)[0][1])
      for qid, answers in predictions.items()
  }
  metrics = evaluation.evaluate_triviaqa(ground_truth, predictions, mute=True)
  metrics['loss'] = loss_metric.result().numpy()
  return metrics
def main(argv):
  """Trains the TriviaQA model and evaluates it after every epoch."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  gin.parse_config(FLAGS.gin_bindings)
  # Build the encoder configuration from the JSON model config, applying the
  # flag-driven overrides (dropout rates, BigBird block size).
  model_config = read_model_config(
      FLAGS.encoder,
      FLAGS.model_config_path,
      bigbird_block_size=FLAGS.bigbird_block_size)
  logging.info(model_config.get().as_dict())
  # Configure input processing.
  sp_processor = read_sentencepiece_model(FLAGS.sentencepiece_model_path)
  # Resolve the special-piece ids from the SentencePiece vocabulary once and
  # close over them in the feature-mapping function shared by train/eval.
  features_map_fn = functools.partial(
      inputs.features_map_fn,
      local_radius=FLAGS.bigbird_block_size,
      relative_pos_max_distance=24,
      use_hard_g2l_mask=True,
      padding_id=sp_processor.PieceToId('<pad>'),
      eos_id=sp_processor.PieceToId('</s>'),
      null_id=sp_processor.PieceToId('<empty>'),
      cls_id=sp_processor.PieceToId('<ans>'),
      sep_id=sp_processor.PieceToId('<sep_0>'))
  train_features_map_fn = tf.function(
      functools.partial(
          features_map_fn,
          sequence_length=FLAGS.train_sequence_length,
          global_sequence_length=FLAGS.train_global_sequence_length),
      autograph=False)
  train_labels_map_fn = tf.function(
      functools.partial(
          inputs.labels_map_fn, sequence_length=FLAGS.train_sequence_length))
  # Connect to the TPU cluster (if any) and initialize the TPU system before
  # datasets and the model are built; otherwise mirror across local devices.
  if FLAGS.master:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(FLAGS.master)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
  else:
    strategy = tf.distribute.MirroredStrategy()
  # Initialize datasets.
  with worker_context():
    # NOTE(review): forces creation of the global random generator inside the
    # worker context — presumably so later stateful random ops (e.g. dataset
    # shuffling) share consistently-placed state; confirm against
    # inputs.read_batches.
    _ = tf.random.get_global_generator()
    train_dataset = inputs.read_batches(
        FLAGS.data_dir,
        tfds.Split.TRAIN,
        FLAGS.batch_size,
        shuffle=True,
        drop_final_batch=True)
    validation_dataset = inputs.read_batches(FLAGS.data_dir,
                                             tfds.Split.VALIDATION,
                                             FLAGS.batch_size)

    def train_map_fn(x, y):
      # Maps raw batches into model features and label-smoothed span targets.
      features = train_features_map_fn(x)
      labels = modeling.smooth_labels(FLAGS.label_smoothing,
                                      train_labels_map_fn(x['token_ids'], y),
                                      features['question_lengths'],
                                      features['token_ids'])
      return features, labels

    train_dataset = train_dataset.map(train_map_fn, 16).prefetch(16)
  # Initialize model and compile.
  with strategy.scope():
    model = modeling.TriviaQaModel(model_config, FLAGS.train_sequence_length)
  # Distributed inference helpers used by the evaluation loop.
  logits_fn = tf.function(
      functools.partial(prediction.distributed_logits_fn, model))
  decode_logits_fn = tf.function(
      functools.partial(prediction.decode_logits, FLAGS.decode_top_k,
                        FLAGS.decode_max_size))
  split_and_pad_fn = tf.function(
      functools.partial(prediction.split_and_pad, strategy, FLAGS.batch_size))
  # Evaluation strategy: load the gold answers keyed by question id.
  with tf.io.gfile.GFile(FLAGS.validation_gold_path) as f:
    ground_truth = {
        datum['QuestionId']: datum['Answer'] for datum in json.load(f)['Data']
    }
  validation_features_map_fn = tf.function(
      functools.partial(
          features_map_fn,
          sequence_length=FLAGS.validation_sequence_length,
          global_sequence_length=FLAGS.validation_global_sequence_length),
      autograph=False)
  validation_labels_map_fn = tf.function(
      functools.partial(
          inputs.labels_map_fn,
          sequence_length=FLAGS.validation_sequence_length))
  # Bind everything evaluate() needs so fit() can call it with no arguments.
  evaluate_fn = functools.partial(
      evaluate,
      sp_processor=sp_processor,
      features_map_fn=validation_features_map_fn,
      labels_map_fn=validation_labels_map_fn,
      logits_fn=logits_fn,
      decode_logits_fn=decode_logits_fn,
      split_and_pad_fn=split_and_pad_fn,
      distribute_strategy=strategy,
      validation_dataset=validation_dataset,
      ground_truth=ground_truth)
  logging.info('Model initialized. Beginning training fit loop.')
  fit(model, strategy, train_dataset, FLAGS.model_dir,
      FLAGS.init_checkpoint_path, evaluate_fn)
if __name__ == '__main__':
  # These flags have no usable defaults, so fail fast if any is missing.
  flags.mark_flags_as_required([
      'model_config_path', 'model_dir', 'sentencepiece_model_path',
      'validation_gold_path'
  ])
  app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment