Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0ab5dcbf
Commit
0ab5dcbf
authored
Oct 01, 2020
by
Philip Pham
Committed by
A. Unique TensorFlower
Oct 01, 2020
Browse files
Add TriviaQA Task to projects
PiperOrigin-RevId: 334950562
parent
ec955c21
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
2863 additions
and
0 deletions
+2863
-0
official/nlp/projects/triviaqa/dataset.py
official/nlp/projects/triviaqa/dataset.py
+455
-0
official/nlp/projects/triviaqa/download_and_prepare.py
official/nlp/projects/triviaqa/download_and_prepare.py
+71
-0
official/nlp/projects/triviaqa/evaluate.py
official/nlp/projects/triviaqa/evaluate.py
+47
-0
official/nlp/projects/triviaqa/evaluation.py
official/nlp/projects/triviaqa/evaluation.py
+169
-0
official/nlp/projects/triviaqa/inputs.py
official/nlp/projects/triviaqa/inputs.py
+548
-0
official/nlp/projects/triviaqa/modeling.py
official/nlp/projects/triviaqa/modeling.py
+113
-0
official/nlp/projects/triviaqa/predict.py
official/nlp/projects/triviaqa/predict.py
+184
-0
official/nlp/projects/triviaqa/prediction.py
official/nlp/projects/triviaqa/prediction.py
+68
-0
official/nlp/projects/triviaqa/preprocess.py
official/nlp/projects/triviaqa/preprocess.py
+514
-0
official/nlp/projects/triviaqa/sentencepiece_pb2.py
official/nlp/projects/triviaqa/sentencepiece_pb2.py
+311
-0
official/nlp/projects/triviaqa/train.py
official/nlp/projects/triviaqa/train.py
+383
-0
No files found.
official/nlp/projects/triviaqa/dataset.py
0 → 100644
View file @
0ab5dcbf
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TriviaQA: A Reading Comprehension Dataset."""
import
functools
import
json
import
os
from
absl
import
logging
import
apache_beam
as
beam
import
six
import
tensorflow
as
tf
import
tensorflow_datasets.public_api
as
tfds
from
official.nlp.projects.triviaqa
import
preprocess
# BibTeX entry exported in the dataset's TFDS metadata.
_CITATION = """
@article{2017arXivtriviaqa,
author = {{Joshi}, Mandar and {Choi}, Eunsol and {Weld},
Daniel and {Zettlemoyer}, Luke},
title = "{triviaqa: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension}",
journal = {arXiv e-prints},
year = 2017,
eid = {arXiv:1705.03551},
pages = {arXiv:1705.03551},
archivePrefix = {arXiv},
eprint = {1705.03551},
}
"""

# Tarball URL template; the placeholder is filled with "rc" or "unfiltered".
_DOWNLOAD_URL_TMPL = (
    "http://nlp.cs.washington.edu/triviaqa/data/triviaqa-{}.tar.gz")

# Glob patterns matching the per-split QA JSON files inside the archives.
_TRAIN_FILE_FORMAT = "*-train.json"
_VALIDATION_FILE_FORMAT = "*-dev.json"
_TEST_FILE_FORMAT = "*test-without-answers.json"

# Evidence-document subdirectories inside the "rc" archive.
_WEB_EVIDENCE_DIR = "evidence/web"
_WIKI_EVIDENCE_DIR = "evidence/wikipedia"

# NOTE: the "TriviaqQA" spelling below is preserved from the original
# description string (it is runtime data shown to dataset users).
_DESCRIPTION = """\
TriviaqQA is a reading comprehension dataset containing over 650K
question-answer-evidence triples. TriviaqQA includes 95K question-answer
pairs authored by trivia enthusiasts and independently gathered evidence
documents, six per question on average, that provide high quality distant
supervision for answering the questions.
"""

_RC_DESCRIPTION = """\
Question-answer pairs where all documents for a given question contain the
answer string(s).
"""

_UNFILTERED_DESCRIPTION = """\
110k question-answer pairs for open domain QA where not all documents for a
given question contain the answer string(s). This makes the unfiltered dataset
more appropriate for IR-style QA.
"""

# Appended to the description when context is included in the config.
_CONTEXT_ADDENDUM = "Includes context from Wikipedia and search results."
def _web_evidence_dir(tmp_dir):
  """Globs the web evidence directory under `tmp_dir`."""
  pattern = os.path.join(tmp_dir, _WEB_EVIDENCE_DIR)
  return tf.io.gfile.glob(pattern)
def _wiki_evidence_dir(tmp_dir):
  """Globs the Wikipedia evidence directory under `tmp_dir`."""
  pattern = os.path.join(tmp_dir, _WIKI_EVIDENCE_DIR)
  return tf.io.gfile.glob(pattern)
class TriviaQAConfig(tfds.core.BuilderConfig):
  """BuilderConfig for TriviaQA."""

  def __init__(self, *, unfiltered=False, exclude_context=False, **kwargs):
    """BuilderConfig for TriviaQA.

    Args:
      unfiltered: bool, whether to use the unfiltered version of the dataset,
        intended for open-domain QA.
      exclude_context: bool, whether to exclude Wikipedia and search context for
        reduced size.
      **kwargs: keyword arguments forwarded to super.
    """
    # Config name encodes both axes, e.g. "rc", "rc.nocontext",
    # "unfiltered", "unfiltered.nocontext".
    parts = ["unfiltered" if unfiltered else "rc"]
    if exclude_context:
      parts.append(".nocontext")
    name = "".join(parts)
    description = _UNFILTERED_DESCRIPTION if unfiltered else _RC_DESCRIPTION
    if not exclude_context:
      description += _CONTEXT_ADDENDUM
    super(TriviaQAConfig, self).__init__(
        name=name,
        description=description,
        version=tfds.core.Version("1.1.1"),
        **kwargs)
    self.unfiltered = unfiltered
    self.exclude_context = exclude_context
class BigBirdTriviaQAConfig(tfds.core.BuilderConfig):
  """BuilderConfig for TriviaQA preprocessed for BigBird.

  Unlike `TriviaQAConfig`, tokenization parameters are supplied after
  construction via `configure()` (they come from command-line flags), and
  `validate()` is called before the Beam pipeline runs.
  """

  def __init__(self, **kwargs):
    """BuilderConfig for TriviaQA.

    Args:
      **kwargs: keyword arguments forwarded to super.
    """
    name = "rc_wiki.preprocessed"
    description = _RC_DESCRIPTION
    super(BigBirdTriviaQAConfig, self).__init__(
        name=name,
        description=description,
        version=tfds.core.Version("1.1.1"),
        **kwargs)
    # Fixed for this config: BigBird preprocessing always uses the filtered
    # ("rc") data with context included.
    self.unfiltered = False
    self.exclude_context = False

  def configure(self,
                sentencepiece_model_path,
                sequence_length,
                stride,
                global_sequence_length=None):
    """Configures additional user-specified arguments."""
    self.sentencepiece_model_path = sentencepiece_model_path
    self.sequence_length = sequence_length
    self.stride = stride
    if global_sequence_length is None and sequence_length is not None:
      # Default heuristic for the number of global tokens; mirrors the
      # default documented in download_and_prepare.py's flag help.
      self.global_sequence_length = sequence_length // 16 + 64
    else:
      self.global_sequence_length = global_sequence_length
    logging.info(
        """
    global_sequence_length: %s
    sequence_length: %s
    stride: %s
    sentencepiece_model_path: %s""", self.global_sequence_length,
        self.sequence_length, self.stride, self.sentencepiece_model_path)

  def validate(self):
    """Validates that user specifies valid arguments."""
    # All three must be set via configure() before building the dataset.
    if self.sequence_length is None:
      raise ValueError("sequence_length must be specified for BigBird.")
    if self.stride is None:
      raise ValueError("stride must be specified for BigBird.")
    if self.sentencepiece_model_path is None:
      raise ValueError("sentencepiece_model_path must be specified for BigBird.")
def filter_files_for_big_bird(files):
  """Keeps only the Wikipedia-domain file from a list of QA JSON paths."""
  wiki_files = [path for path in files
                if os.path.basename(path).startswith("wiki")]
  assert len(wiki_files) == 1, "There should only be one wikipedia file."
  return wiki_files
class TriviaQA(tfds.core.BeamBasedBuilder):
  """TriviaQA is a reading comprehension dataset.

  It contains over 650K question-answer-evidence triples.
  """
  name = "bigbird_trivia_qa"
  BUILDER_CONFIGS = [
      BigBirdTriviaQAConfig(),
      TriviaQAConfig(unfiltered=False, exclude_context=False),  # rc
      TriviaQAConfig(unfiltered=False, exclude_context=True),  # rc.nocontext
      TriviaQAConfig(unfiltered=True, exclude_context=False),  # unfiltered
      TriviaQAConfig(unfiltered=True, exclude_context=True),  # unfiltered.nocontext
  ]

  def __init__(self,
               *,
               sentencepiece_model_path=None,
               sequence_length=None,
               stride=None,
               global_sequence_length=None,
               **kwargs):
    """Builder; the keyword-only args apply only to the BigBird config."""
    super(TriviaQA, self).__init__(**kwargs)
    # Forward tokenization parameters to the BigBird config; the plain
    # TriviaQAConfig variants ignore them.
    if isinstance(self.builder_config, BigBirdTriviaQAConfig):
      self.builder_config.configure(
          sentencepiece_model_path=sentencepiece_model_path,
          sequence_length=sequence_length,
          stride=stride,
          global_sequence_length=global_sequence_length)

  def _info(self):
    """Returns DatasetInfo; feature schema differs for the BigBird config."""
    if isinstance(self.builder_config, BigBirdTriviaQAConfig):
      # Pre-tokenized schema produced by preprocess.make_pipeline.
      return tfds.core.DatasetInfo(
          builder=self,
          description=_DESCRIPTION,
          supervised_keys=None,
          homepage="http://nlp.cs.washington.edu/triviaqa/",
          citation=_CITATION,
          features=tfds.features.FeaturesDict({
              "id": tfds.features.Text(),
              "qid": tfds.features.Text(),
              "question": tfds.features.Text(),
              "context": tfds.features.Text(),
              # Sequence features.
              "token_ids":
                  tfds.features.Tensor(shape=(None,), dtype=tf.int64),
              "token_offsets":
                  tfds.features.Tensor(shape=(None,), dtype=tf.int64),
              "segment_ids":
                  tfds.features.Tensor(shape=(None,), dtype=tf.int64),
              "global_token_ids":
                  tfds.features.Tensor(shape=(None,), dtype=tf.int64),
              # Start and end indices (inclusive).
              "answers":
                  tfds.features.Tensor(shape=(None, 2), dtype=tf.int64),
          }))
    # Raw-text schema mirroring the upstream TriviaQA JSON structure.
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            "question": tfds.features.Text(),
            "question_id": tfds.features.Text(),
            "question_source": tfds.features.Text(),
            "entity_pages": tfds.features.Sequence({
                "doc_source": tfds.features.Text(),
                "filename": tfds.features.Text(),
                "title": tfds.features.Text(),
                "wiki_context": tfds.features.Text(),
            }),
            "search_results": tfds.features.Sequence({
                "description": tfds.features.Text(),
                "filename": tfds.features.Text(),
                "rank": tf.int32,
                "title": tfds.features.Text(),
                "url": tfds.features.Text(),
                "search_context": tfds.features.Text(),
            }),
            "answer": tfds.features.FeaturesDict({
                "aliases": tfds.features.Sequence(tfds.features.Text()),
                "normalized_aliases":
                    tfds.features.Sequence(tfds.features.Text()),
                "matched_wiki_entity_name": tfds.features.Text(),
                "normalized_matched_wiki_entity_name": tfds.features.Text(),
                "normalized_value": tfds.features.Text(),
                "type": tfds.features.Text(),
                "value": tfds.features.Text(),
            }),
        }),
        supervised_keys=None,
        homepage="http://nlp.cs.washington.edu/triviaqa/",
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager):
    """Returns SplitGenerators."""
    cfg = self.builder_config
    download_urls = dict()
    # The "rc" archive is needed both for rc configs and, when context is
    # included, for the evidence directories used by unfiltered configs.
    if not (cfg.unfiltered and cfg.exclude_context):
      download_urls["rc"] = _DOWNLOAD_URL_TMPL.format("rc")
    if cfg.unfiltered:
      download_urls["unfiltered"] = _DOWNLOAD_URL_TMPL.format("unfiltered")
    file_paths = dl_manager.download_and_extract(download_urls)
    qa_dir = (
        os.path.join(file_paths["unfiltered"], "triviaqa-unfiltered")
        if cfg.unfiltered else os.path.join(file_paths["rc"], "qa"))
    train_files = tf.io.gfile.glob(os.path.join(qa_dir, _TRAIN_FILE_FORMAT))
    valid_files = tf.io.gfile.glob(
        os.path.join(qa_dir, _VALIDATION_FILE_FORMAT))
    test_files = tf.io.gfile.glob(os.path.join(qa_dir, _TEST_FILE_FORMAT))
    if cfg.exclude_context:
      web_evidence_dir = None
      wiki_evidence_dir = None
    else:
      web_evidence_dir = os.path.join(file_paths["rc"], _WEB_EVIDENCE_DIR)
      wiki_evidence_dir = os.path.join(file_paths["rc"], _WIKI_EVIDENCE_DIR)
    # BigBird preprocessing uses only the single Wikipedia-domain file.
    if isinstance(cfg, BigBirdTriviaQAConfig):
      train_files = filter_files_for_big_bird(train_files)
      valid_files = filter_files_for_big_bird(valid_files)
      test_files = filter_files_for_big_bird(test_files)
    # The test split has no answers in the source data ("answer": False).
    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            gen_kwargs={
                "files": train_files,
                "web_dir": web_evidence_dir,
                "wiki_dir": wiki_evidence_dir,
                "answer": True
            }),
        tfds.core.SplitGenerator(
            name=tfds.Split.VALIDATION,
            gen_kwargs={
                "files": valid_files,
                "web_dir": web_evidence_dir,
                "wiki_dir": wiki_evidence_dir,
                "answer": True
            }),
        tfds.core.SplitGenerator(
            name=tfds.Split.TEST,
            gen_kwargs={
                "files": test_files,
                "web_dir": web_evidence_dir,
                "wiki_dir": wiki_evidence_dir,
                "answer": False
            }),
    ]

  def _build_pcollection(self, pipeline, files, web_dir, wiki_dir, answer):
    """Builds the Beam pipeline that produces examples for one split."""
    if isinstance(self.builder_config, BigBirdTriviaQAConfig):
      self.builder_config.validate()
      # BigBird path: `files` holds exactly one Wikipedia-domain file
      # (guaranteed by filter_files_for_big_bird).
      question_answers = preprocess.read_question_answers(files[0])
      return preprocess.make_pipeline(
          pipeline,
          question_answers=question_answers,
          answer=answer,
          max_num_tokens=self.builder_config.sequence_length,
          max_num_global_tokens=self.builder_config.global_sequence_length,
          stride=self.builder_config.stride,
          sentencepiece_model_path=self.builder_config.sentencepiece_model_path,
          wikipedia_dir=wiki_dir,
          web_dir=web_dir)
    # Raw-text path: read every question JSON, shuffle, then parse each
    # article into a (key, example) pair.
    parse_example_fn = functools.partial(parse_example,
                                         self.builder_config.exclude_context,
                                         web_dir, wiki_dir)
    return (pipeline
            | beam.Create(files)
            | beam.ParDo(ReadQuestions())
            | beam.Reshuffle()
            | beam.Map(parse_example_fn))
class ReadQuestions(beam.DoFn):
  """Read questions from JSON."""

  def process(self, file):
    """Yields one dict per question, tagged with its source file name."""
    with tf.io.gfile.GFile(file) as handle:
      payload = json.load(handle)
    source = os.path.basename(file)
    for question in payload["Data"]:
      record = {"SourceFile": source}
      record.update(question)
      yield record
def parse_example(exclude_context, web_dir, wiki_dir, article):
  """Return a single example from an article JSON record.

  Args:
    exclude_context: bool, drop search/Wikipedia evidence when True.
    web_dir: directory containing web search evidence files (may be None
      when exclude_context is True).
    wiki_dir: directory containing Wikipedia evidence files (may be None
      when exclude_context is True).
    article: dict, one question record from the TriviaQA JSON, with a
      "SourceFile" key added by ReadQuestions.

  Returns:
    A (key, example) pair where key is "<source file>_<question id>" and
    example matches the raw-text feature schema in TriviaQA._info.
  """

  def _strip(collection):
    return [item.strip() for item in collection]

  if "Answer" in article:
    answer = article["Answer"]
    answer_dict = {
        "aliases":
            _strip(answer["Aliases"]),
        "normalized_aliases":
            _strip(answer["NormalizedAliases"]),
        # "MatchedWikiEntryName" is optional in the source JSON.
        "matched_wiki_entity_name":
            answer.get("MatchedWikiEntryName", "").strip(),
        "normalized_matched_wiki_entity_name":
            answer.get("NormalizedMatchedWikiEntryName", "").strip(),
        "normalized_value":
            answer["NormalizedValue"].strip(),
        "type":
            answer["Type"].strip(),
        "value":
            answer["Value"].strip(),
    }
  else:
    # Test-split records have no answers; fill with placeholders so every
    # example matches the same feature schema.
    answer_dict = {
        "aliases": [],
        "normalized_aliases": [],
        "matched_wiki_entity_name": "<unk>",
        "normalized_matched_wiki_entity_name": "<unk>",
        "normalized_value": "<unk>",
        "type": "",
        "value": "<unk>",
    }
  if exclude_context:
    article["SearchResults"] = []
    article["EntityPages"] = []

  def _add_context(collection, context_field, file_dir):
    """Adds context from file, or skips if file does not exist."""
    new_items = []
    for item in collection:
      if "Filename" not in item:
        logging.info("Missing context 'Filename', skipping.")
        continue
      new_item = item.copy()
      fname = item["Filename"]
      try:
        with tf.io.gfile.GFile(os.path.join(file_dir, fname)) as f:
          new_item[context_field] = f.read()
      except (IOError, tf.errors.NotFoundError):
        # Best-effort: evidence files missing from the archive are dropped.
        logging.info("File does not exist, skipping: %s", fname)
        continue
      new_items.append(new_item)
    return new_items

  def _strip_if_str(v):
    return v.strip() if isinstance(v, six.string_types) else v

  def _transpose_and_strip_dicts(dicts, field_names):
    # Row-of-dicts -> dict-of-columns, with CamelCase keys converted to
    # snake_case to match the feature schema.
    return {
        tfds.core.naming.camelcase_to_snakecase(k):
            [_strip_if_str(d[k]) for d in dicts]
        for k in field_names
    }

  search_results = _transpose_and_strip_dicts(
      _add_context(article.get("SearchResults", []), "SearchContext", web_dir),
      ["Description", "Filename", "Rank", "Title", "Url", "SearchContext"])
  entity_pages = _transpose_and_strip_dicts(
      _add_context(article.get("EntityPages", []), "WikiContext", wiki_dir),
      ["DocSource", "Filename", "Title", "WikiContext"])
  question = article["Question"].strip()
  question_id = article["QuestionId"]
  question_source = article["QuestionSource"].strip()
  return f"{article['SourceFile']}_{question_id}", {
      "entity_pages": entity_pages,
      "search_results": search_results,
      "question": question,
      "question_id": question_id,
      "question_source": question_source,
      "answer": answer_dict,
  }
official/nlp/projects/triviaqa/download_and_prepare.py
0 → 100644
View file @
0ab5dcbf
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Downloads and prepares TriviaQA dataset."""
from
unittest
import
mock
from
absl
import
app
from
absl
import
flags
from
absl
import
logging
import
apache_beam
as
beam
import
tensorflow_datasets
as
tfds
from
official.nlp.projects.triviaqa
import
dataset
# pylint: disable=unused-import
# Command-line flags controlling BigBird tokenization and the Beam runner.
flags.DEFINE_integer('sequence_length', 4096, 'Max number of tokens.')

flags.DEFINE_integer(
    'global_sequence_length', None,
    'Max number of question tokens plus sentences. If not set, defaults to '
    'sequence_length // 16 + 64.')

flags.DEFINE_integer(
    'stride', 3072,
    'For documents longer than `sequence_length`, where to split them.')

flags.DEFINE_string('sentencepiece_model_path', None,
                    'SentencePiece model to use for tokenization.')

flags.DEFINE_string('data_dir', None, 'Data directory for TFDS.')

flags.DEFINE_string('runner', 'DirectRunner', 'Beam runner to use.')

FLAGS = flags.FLAGS
def main(argv):
  """Builds the preprocessed BigBird TriviaQA dataset with Beam."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  builder = tfds.builder(
      'bigbird_trivia_qa/rc_wiki.preprocessed',
      data_dir=FLAGS.data_dir,
      sentencepiece_model_path=FLAGS.sentencepiece_model_path,
      sequence_length=FLAGS.sequence_length,
      global_sequence_length=FLAGS.global_sequence_length,
      stride=FLAGS.stride)
  download_config = tfds.download.DownloadConfig(
      beam_options=beam.options.pipeline_options.PipelineOptions(flags=[
          f'--runner={FLAGS.runner}',
          '--direct_num_workers=8',
          '--direct_running_mode=multi_processing',
      ]))
  # NOTE(review): patching the TFDS extractor's path normalization to the
  # identity function — presumably to keep evidence file names inside the
  # archive byte-identical; confirm against tensorflow_datasets internals.
  with mock.patch('tensorflow_datasets.core.download.extractor._normpath',
                  new=lambda x: x):
    builder.download_and_prepare(download_config=download_config)
  logging.info(builder.info.splits)
if __name__ == '__main__':
  # Tokenization cannot run without a SentencePiece model.
  flags.mark_flag_as_required('sentencepiece_model_path')
  app.run(main)
official/nlp/projects/triviaqa/evaluate.py
0 → 100644
View file @
0ab5dcbf
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evalutes TriviaQA predictions."""
import
json
from
absl
import
app
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
from
official.nlp.projects.triviaqa
import
evaluation
# Input paths for the evaluation script.
flags.DEFINE_string('gold_path', None,
                    'Path to golden validation, i.e. wikipedia-dev.json.')

flags.DEFINE_string('predictions_path', None,
                    'Path to predictions in JSON format')

FLAGS = flags.FLAGS
def main(argv):
  """Scores a predictions JSON against the gold TriviaQA answers."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  # Gold file is the upstream TriviaQA JSON: {"Data": [{"QuestionId": ...,
  # "Answer": {...}}, ...]}.
  with tf.io.gfile.GFile(FLAGS.gold_path) as f:
    ground_truth = {
        datum['QuestionId']: datum['Answer'] for datum in json.load(f)['Data']
    }
  # Predictions file maps question id -> answer string.
  with tf.io.gfile.GFile(FLAGS.predictions_path) as f:
    predictions = json.load(f)
  logging.info(evaluation.evaluate_triviaqa(ground_truth, predictions))
if __name__ == '__main__':
  # main() reads both files unconditionally; require both flags up front so
  # a missing --gold_path fails with a clear usage error instead of a
  # cryptic GFile(None) failure.
  flags.mark_flag_as_required('gold_path')
  flags.mark_flag_as_required('predictions_path')
  app.run(main)
official/nlp/projects/triviaqa/evaluation.py
0 → 100644
View file @
0ab5dcbf
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 Google LLC
# Copyright 2017 Mandar Joshi (mandar90@cs.washington.edu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Official evaluation script for v1.0 of the TriviaQA dataset.
Forked from
https://github.com/mandarjoshi90/triviaqa/blob/master/evaluation/triviaqa_evaluation.py.
Modifications are removal of main function.
"""
import
collections
import
re
import
string
import
sys
def normalize_answer(s):
  """Lower text and remove punctuation, articles and extra whitespace."""
  # Punctuation plus a few typographic quote characters, all mapped to space.
  excluded = set(string.punctuation) | {u'‘', u'’', u'´', u'`'}
  # Pipeline order: underscores -> lowercase -> punctuation -> articles ->
  # whitespace collapse (same order as the upstream TriviaQA script).
  text = s.replace('_', ' ').lower()
  text = ''.join(' ' if ch in excluded else ch for ch in text)
  text = re.sub(r'\b(a|an|the)\b', ' ', text)
  return ' '.join(text.split()).strip()
def f1_score(prediction, ground_truth):
  """Token-level F1 between normalized prediction and ground truth."""
  pred_tokens = normalize_answer(prediction).split()
  truth_tokens = normalize_answer(ground_truth).split()
  overlap = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
  num_same = sum(overlap.values())
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(pred_tokens)
  recall = 1.0 * num_same / len(truth_tokens)
  return (2 * precision * recall) / (precision + recall)
def exact_match_score(prediction, ground_truth):
  """True when prediction and ground truth normalize to the same string."""
  normalized_prediction = normalize_answer(prediction)
  normalized_truth = normalize_answer(ground_truth)
  return normalized_prediction == normalized_truth
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
  """Best metric value of the prediction over all ground-truth answers."""
  return max(metric_fn(prediction, truth) for truth in ground_truths)
def is_exact_match(answer_object, prediction):
  """True if the prediction exactly matches any ground-truth answer."""
  return any(
      exact_match_score(prediction, truth)
      for truth in get_ground_truths(answer_object))
def has_exact_match(ground_truths, candidates):
  """True if any ground truth is contained in `candidates` (via `in`)."""
  return any(truth in candidates for truth in ground_truths)
def get_ground_truths(answer):
  """All acceptable answers: normalized aliases plus normalized human answers."""
  human_answers = [
      normalize_answer(ans) for ans in answer.get('HumanAnswers', [])
  ]
  return answer['NormalizedAliases'] + human_answers
def get_oracle_score(ground_truth, predicted_answers, qid_list=None,
                     mute=False):
  """Oracle exact-match score over question ids (forked upstream script).

  Args:
    ground_truth: dict qid -> Answer object from the gold JSON.
    predicted_answers: dict qid -> predicted answer string.
    qid_list: optional iterable of qids to score; defaults to all gold qids.
    mute: bool, suppress stderr warnings about missing qids.

  Returns:
    Dict with 'oracle_exact_match' (percentage), 'common', 'denominator',
    'pred_len' and 'gold_len'.
  """
  exact_match = common = 0
  if qid_list is None:
    qid_list = ground_truth.keys()
  for qid in qid_list:
    if qid not in predicted_answers:
      if not mute:
        message = 'Irrelavant question {} will receive score 0.'.format(qid)
        print(message, file=sys.stderr)
      continue
    common += 1
    prediction = normalize_answer(predicted_answers[qid])
    ground_truths = get_ground_truths(ground_truth[qid])
    # NOTE: `prediction` is a string here, so has_exact_match tests each
    # ground truth for *substring* containment in the normalized prediction
    # (Python `in` on str) — hence "oracle" score.
    em_for_this_question = has_exact_match(ground_truths, prediction)
    exact_match += int(em_for_this_question)
  # Raises ZeroDivisionError if qid_list is empty (upstream behavior kept).
  exact_match = 100.0 * exact_match / len(qid_list)
  return {
      'oracle_exact_match': exact_match,
      'common': common,
      'denominator': len(qid_list),
      'pred_len': len(predicted_answers),
      'gold_len': len(ground_truth)
  }
def evaluate_triviaqa(ground_truth, predicted_answers, qid_list=None,
                      mute=False):
  """Computes exact-match and F1 metrics (forked upstream script).

  Args:
    ground_truth: dict qid -> Answer object from the gold JSON.
    predicted_answers: dict qid -> predicted answer string.
    qid_list: optional iterable of qids to score; defaults to all gold qids.
      Unanswered or unknown qids count as 0 in the averages.
    mute: bool, suppress stderr/stdout diagnostics.

  Returns:
    Dict with 'exact_match' and 'f1' percentages plus bookkeeping counts
    'common', 'denominator', 'pred_len', 'gold_len'.
  """
  f1 = exact_match = common = 0
  if qid_list is None:
    qid_list = ground_truth.keys()
  for qid in qid_list:
    if qid not in predicted_answers:
      if not mute:
        message = 'Missed question {} will receive score 0.'.format(qid)
        print(message, file=sys.stderr)
      continue
    if qid not in ground_truth:
      if not mute:
        message = 'Irrelavant question {} will receive score 0.'.format(qid)
        print(message, file=sys.stderr)
      continue
    common += 1
    prediction = predicted_answers[qid]
    ground_truths = get_ground_truths(ground_truth[qid])
    em_for_this_question = metric_max_over_ground_truths(
        exact_match_score, prediction, ground_truths)
    if em_for_this_question == 0 and not mute:
      print('em=0:', prediction, ground_truths)
    exact_match += em_for_this_question
    f1_for_this_question = metric_max_over_ground_truths(
        f1_score, prediction, ground_truths)
    f1 += f1_for_this_question
  # Averages are over qid_list, so skipped questions still count as zeros.
  # Raises ZeroDivisionError if qid_list is empty (upstream behavior kept).
  exact_match = 100.0 * exact_match / len(qid_list)
  f1 = 100.0 * f1 / len(qid_list)
  return {
      'exact_match': exact_match,
      'f1': f1,
      'common': common,
      'denominator': len(qid_list),
      'pred_len': len(predicted_answers),
      'gold_len': len(ground_truth)
  }
official/nlp/projects/triviaqa/inputs.py
0 → 100644
View file @
0ab5dcbf
This diff is collapsed.
Click to expand it.
official/nlp/projects/triviaqa/modeling.py
0 → 100644
View file @
0ab5dcbf
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modeling for TriviaQA."""
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.nlp.configs
import
encoders
class TriviaQaHead(tf.keras.layers.Layer):
  """Computes logits given token and global embeddings.

  A feed-forward block with a residual connection and layer norm, followed
  by a 2-unit dense layer producing begin/end span logits per token.
  """

  def __init__(self,
               intermediate_size,
               intermediate_activation=tf_utils.get_activation('gelu'),
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               **kwargs):
    """Builds the head's layers.

    Args:
      intermediate_size: int, width of the intermediate dense layer.
      intermediate_activation: activation for the intermediate layer.
      dropout_rate: dropout applied to the feed-forward output.
      attention_dropout_rate: dropout applied to the input embeddings.
      **kwargs: forwarded to tf.keras.layers.Layer.
    """
    super(TriviaQaHead, self).__init__(**kwargs)
    self._attention_dropout = tf.keras.layers.Dropout(attention_dropout_rate)
    self._intermediate_dense = tf.keras.layers.Dense(intermediate_size)
    self._intermediate_activation = tf.keras.layers.Activation(
        intermediate_activation)
    self._output_dropout = tf.keras.layers.Dropout(dropout_rate)
    self._output_layer_norm = tf.keras.layers.LayerNormalization()
    self._logits_dense = tf.keras.layers.Dense(2)

  def build(self, input_shape):
    # Project back to the embedding width so the residual add is valid.
    output_shape = input_shape['token_embeddings'][-1]
    self._output_dense = tf.keras.layers.Dense(output_shape)
    super(TriviaQaHead, self).build(input_shape)

  def call(self, inputs, training=None):
    """Computes span logits.

    Args:
      inputs: dict with 'token_embeddings', 'token_ids', 'question_lengths'.
      training: bool, enables dropout.

    Returns:
      Logits with shape [..., seq_len, 2]; padding tokens (id 0) and
      question-prefix positions are pushed to -1e6 so they are never chosen.
    """
    token_embeddings = inputs['token_embeddings']
    token_ids = inputs['token_ids']
    question_lengths = inputs['question_lengths']
    x = self._attention_dropout(token_embeddings, training=training)
    intermediate_outputs = self._intermediate_dense(x)
    intermediate_outputs = self._intermediate_activation(intermediate_outputs)
    outputs = self._output_dense(intermediate_outputs)
    outputs = self._output_dropout(outputs, training=training)
    outputs = self._output_layer_norm(outputs + token_embeddings)
    logits = self._logits_dense(outputs)
    # Mask out padding (token id 0) and the question prefix with a large
    # negative offset instead of -inf to keep the logits finite.
    logits -= tf.expand_dims(
        tf.cast(tf.equal(token_ids, 0), tf.float32) + tf.sequence_mask(
            question_lengths, logits.shape[-2], dtype=tf.float32), -1) * 1e6
    return logits
class TriviaQaModel(tf.keras.Model):
  """Model for TriviaQA.

  Functional Keras model: encoder followed by TriviaQaHead, mapping
  {'token_ids', 'question_lengths'} to per-token begin/end span logits.
  """

  def __init__(self, model_config: encoders.EncoderConfig,
               sequence_length: int, **kwargs):
    inputs = dict(
        token_ids=tf.keras.Input((sequence_length,), dtype=tf.int32),
        question_lengths=tf.keras.Input((), dtype=tf.int32))
    encoder = encoders.build_encoder(model_config)
    # input_mask: 1 for real tokens (id > 0), 0 for padding.
    # input_type_ids: 0 for the question prefix, 1 for the context.
    x = encoder(
        dict(
            input_word_ids=inputs['token_ids'],
            input_mask=tf.cast(inputs['token_ids'] > 0, tf.int32),
            input_type_ids=1 - tf.sequence_mask(
                inputs['question_lengths'], sequence_length,
                tf.int32)))['sequence_output']
    logits = TriviaQaHead(
        model_config.get().intermediate_size,
        dropout_rate=model_config.get().dropout_rate,
        attention_dropout_rate=model_config.get().attention_dropout_rate)(
            dict(
                token_embeddings=x,
                token_ids=inputs['token_ids'],
                question_lengths=inputs['question_lengths']))
    super(TriviaQaModel, self).__init__(inputs, logits, **kwargs)
    self._encoder = encoder

  @property
  def encoder(self):
    # Exposed so checkpoints can restore the encoder weights directly.
    return self._encoder
class SpanOrCrossEntropyLoss(tf.keras.losses.Loss):
  """Cross entropy loss for multiple correct answers.

  See https://arxiv.org/abs/1710.10723.
  """

  def call(self, y_true, y_pred):
    """Computes -log(sum over correct spans) via a logsumexp difference.

    Args:
      y_true: 0/1 indicators of correct positions, same shape as y_pred.
      y_pred: logits over positions, reduced along axis -2.

    Returns:
      Loss summed over the last axis (begin/end).
    """
    # Push logits of incorrect positions to ~-1e6 so they vanish from the
    # numerator's logsumexp.
    y_pred_masked = y_pred - tf.cast(y_true < 0.5, tf.float32) * 1e6
    or_cross_entropy = (
        tf.math.reduce_logsumexp(y_pred, axis=-2) -
        tf.math.reduce_logsumexp(y_pred_masked, axis=-2))
    return tf.math.reduce_sum(or_cross_entropy, -1)
def smooth_labels(label_smoothing, labels, question_lengths, token_ids):
  """Applies label smoothing restricted to valid (context) positions.

  Args:
    label_smoothing: float in [0, 1], smoothing mass to redistribute.
    labels: 0/1 span labels; last two axes are (positions, 2) — see usage
      of labels.shape[-2] below. Exact rank assumed from callers — confirm.
    question_lengths: per-example length of the question prefix.
    token_ids: token ids; id 0 marks padding.

  Returns:
    Smoothed labels, zeroed at padding and question-prefix positions.
  """
  # mask is 1 only at context positions: not padding and not question.
  mask = 1. - (
      tf.cast(tf.equal(token_ids, 0), tf.float32) +
      tf.sequence_mask(question_lengths, labels.shape[-2], dtype=tf.float32))
  # Number of valid positions per example; the smoothing mass is spread
  # uniformly over these instead of the full sequence length.
  num_classes = tf.expand_dims(
      tf.math.reduce_sum(mask, -1, keepdims=True), -1)
  labels = (1. - label_smoothing) * labels + (label_smoothing / num_classes)
  return labels * tf.expand_dims(mask, -1)
official/nlp/projects/triviaqa/predict.py
0 → 100644
View file @
0ab5dcbf
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TriviaQA script for inference."""
import
collections
import
contextlib
import
functools
import
json
import
operator
from
absl
import
app
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow_datasets
as
tfds
import
sentencepiece
as
spm
from
official.nlp.configs
import
encoders
# pylint: disable=unused-import
from
official.nlp.projects.triviaqa
import
evaluation
from
official.nlp.projects.triviaqa
import
inputs
from
official.nlp.projects.triviaqa
import
prediction
# Command-line flags for inference.
flags.DEFINE_string('data_dir', None, 'TensorFlow Datasets directory.')

flags.DEFINE_enum('split', None,
                  [tfds.Split.TRAIN, tfds.Split.VALIDATION, tfds.Split.TEST],
                  'For which split to generate predictions.')

flags.DEFINE_string('predictions_path', None, 'Output for predictions.')

flags.DEFINE_string('sentencepiece_model_path', None,
                    'Path to sentence piece model.')

flags.DEFINE_integer('bigbird_block_size', 64,
                     'Size of blocks for sparse block attention.')

flags.DEFINE_string('saved_model_dir', None,
                    'Path from which to initialize model and weights.')

flags.DEFINE_integer('sequence_length', 4096, 'Maximum number of tokens.')

flags.DEFINE_integer('global_sequence_length', 320,
                     'Maximum number of global tokens.')

flags.DEFINE_integer('batch_size', 32, 'Size of batch.')

flags.DEFINE_string('master', '', 'Address of the TPU master.')

flags.DEFINE_integer('decode_top_k', 8,
                     'Maximum number of tokens to consider for begin/end.')

flags.DEFINE_integer('decode_max_size', 16,
                     'Maximum number of sentence pieces in an answer.')

FLAGS = flags.FLAGS
@contextlib.contextmanager
def worker_context():
  """Context manager pinning ops to the TPU worker when one is configured.

  Yields:
    The `/job:worker` device context when `FLAGS.master` is set, otherwise
    nothing (ops stay on the default device).
  """
  if not FLAGS.master:
    yield
    return
  with tf.device('/job:worker') as device:
    yield device
def read_sentencepiece_model(path):
  """Loads a serialized SentencePiece model from `path` via tf.io.gfile."""
  with tf.io.gfile.GFile(path, 'rb') as model_file:
    serialized_model = model_file.read()
  processor = spm.SentencePieceProcessor()
  processor.LoadFromSerializedProto(serialized_model)
  return processor
def predict(sp_processor, features_map_fn, logits_fn, decode_logits_fn,
            split_and_pad_fn, distribute_strategy, dataset):
  """Runs the prediction loop over `dataset` and collects the best answers.

  Args:
    sp_processor: SentencePiece processor used to inspect the begin piece.
    features_map_fn: tf.function mapping raw batch features to model inputs.
    logits_fn: tf.function running the model on each replica.
    decode_logits_fn: tf.function turning logits into (begin, end, scores).
    split_and_pad_fn: tf.function distributing one batch across replicas.
    distribute_strategy: tf.distribute.Strategy the model runs under.
    dataset: dataset of feature dictionaries (answers not required).

  Returns:
    Dict mapping question id to its single highest-scoring, normalized
    answer string.
  """
  predictions = collections.defaultdict(list)
  for _, features in dataset.enumerate():
    token_ids = features['token_ids']
    x = split_and_pad_fn(features_map_fn(features))
    # Concatenate per-replica results back into one batch, then drop the
    # padding rows added by split_and_pad_fn.
    logits = tf.concat(
        distribute_strategy.experimental_local_results(logits_fn(x)), 0)
    logits = logits[:features['token_ids'].shape[0]]
    end_limit = token_ids.row_lengths() - 1  # inclusive
    begin, end, scores = decode_logits_fn(logits, end_limit)
    answers = prediction.decode_answer(features['context'], begin, end,
                                       features['token_offsets'],
                                       end_limit).numpy()
    for j, (qid, token_id, offset, score, answer) in enumerate(
        zip(features['qid'].numpy(),
            tf.gather(features['token_ids'], begin, batch_dims=1).numpy(),
            tf.gather(features['token_offsets'], begin, batch_dims=1).numpy(),
            scores, answers)):
      if not answer:
        logging.info('%s: %s | NO_ANSWER, %f',
                     features['id'][j].numpy().decode('utf-8'),
                     features['question'][j].numpy().decode('utf-8'), score)
        continue
      # A word-initial sentencepiece ('▁' prefix) starts at the separating
      # space in the context, so the extracted span presumably carries a
      # leading space to trim — confirm against decode_answer's offsets.
      if sp_processor.IdToPiece(int(token_id)).startswith('▁') and offset > 0:
        answer = answer[1:]
      logging.info('%s: %s | %s, %f',
                   features['id'][j].numpy().decode('utf-8'),
                   features['question'][j].numpy().decode('utf-8'),
                   answer.decode('utf-8'), score)
      predictions[qid.decode('utf-8')].append((score, answer.decode('utf-8')))
  # Keep only the best-scoring answer per question, normalized for scoring.
  predictions = {
      qid: evaluation.normalize_answer(
          sorted(answers, key=operator.itemgetter(0), reverse=True)[0][1])
      for qid, answers in predictions.items()
  }
  return predictions
def main(argv):
  """Entry point: runs distributed inference and writes JSON predictions."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  # Configure input processing.
  sp_processor = read_sentencepiece_model(FLAGS.sentencepiece_model_path)
  features_map_fn = tf.function(
      functools.partial(
          inputs.features_map_fn,
          local_radius=FLAGS.bigbird_block_size,
          relative_pos_max_distance=24,
          use_hard_g2l_mask=True,
          sequence_length=FLAGS.sequence_length,
          global_sequence_length=FLAGS.global_sequence_length,
          # Special-token ids are looked up from the sentencepiece vocab.
          padding_id=sp_processor.PieceToId('<pad>'),
          eos_id=sp_processor.PieceToId('</s>'),
          null_id=sp_processor.PieceToId('<empty>'),
          cls_id=sp_processor.PieceToId('<ans>'),
          sep_id=sp_processor.PieceToId('<sep_0>')),
      autograph=False)
  # Connect to TPU cluster.
  if FLAGS.master:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(FLAGS.master)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
  else:
    strategy = tf.distribute.MirroredStrategy()
  # Initialize datasets.
  with worker_context():
    _ = tf.random.get_global_generator()
    dataset = inputs.read_batches(
        FLAGS.data_dir, FLAGS.split, FLAGS.batch_size, include_answers=False)
  # Initialize model and compile.
  with strategy.scope():
    model = tf.keras.models.load_model(FLAGS.saved_model_dir, compile=False)
  logging.info('Model initialized. Beginning prediction loop.')
  logits_fn = tf.function(
      functools.partial(prediction.distributed_logits_fn, model))
  decode_logits_fn = tf.function(
      functools.partial(prediction.decode_logits, FLAGS.decode_top_k,
                        FLAGS.decode_max_size))
  split_and_pad_fn = tf.function(
      functools.partial(prediction.split_and_pad, strategy, FLAGS.batch_size))
  # Prediction strategy.
  predict_fn = functools.partial(
      predict,
      sp_processor=sp_processor,
      features_map_fn=features_map_fn,
      logits_fn=logits_fn,
      decode_logits_fn=decode_logits_fn,
      split_and_pad_fn=split_and_pad_fn,
      distribute_strategy=strategy,
      dataset=dataset)
  with worker_context():
    predictions = predict_fn()
  with tf.io.gfile.GFile(FLAGS.predictions_path, 'w') as f:
    json.dump(predictions, f)
if __name__ == '__main__':
  # These flags have no usable defaults, so fail fast if they are missing.
  flags.mark_flags_as_required(['split', 'predictions_path',
                                'saved_model_dir'])
  app.run(main)
official/nlp/projects/triviaqa/prediction.py
0 → 100644
View file @
0ab5dcbf
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for inference."""
import
tensorflow
as
tf
def split_and_pad(strategy, batch_size, x):
  """Split and pad for inference.

  Args:
    strategy: tf.distribute.Strategy to distribute values under.
    batch_size: global batch size; each replica receives
      `batch_size // num_replicas_in_sync` rows.
    x: (possibly nested) structure of tensors with a leading batch axis.

  Returns:
    A matching structure of per-replica values; shards past the end of a
    short batch are padded by gathering row 0 (via zero-padded indices).
  """
  per_replica_size = batch_size // strategy.num_replicas_in_sync

  def slice_fn(x, i):
    # Clamp this replica's shard to the actual batch, then pad the index
    # vector with zeros so every replica gets exactly per_replica_size rows.
    begin = min(x.shape[0], i * per_replica_size)
    end = min(x.shape[0], (i + 1) * per_replica_size)
    indices = tf.range(begin, end, dtype=tf.int32)
    return tf.gather(x, tf.pad(indices, [[0, per_replica_size - end + begin]]))

  # pylint: disable=g-long-lambda
  return tf.nest.map_structure(
      lambda x: strategy.experimental_distribute_values_from_function(
          lambda ctx: slice_fn(x, ctx.replica_id_in_sync_group)), x)
  # pylint: enable=g-long-lambda
def decode_logits(top_k, max_size, logits, default):
  """Get the span from logits.

  Args:
    top_k: number of candidate positions kept for begin and for end.
    max_size: maximum allowed span width in tokens (inclusive).
    logits: per-token logits; the [0, 2, 1] transpose and the [:, 0, :] /
      [:, 1, :] indexing below imply shape [batch, length, 2] with channel 0
      holding begin scores and channel 1 end scores — confirm against the
      model head.
    default: position substituted when no candidate pair satisfies
      0 <= end - begin <= max_size (e.g. the null-answer token index).

  Returns:
    Tuple (begin, end, score) of per-example tensors; `score` is the best
    begin+end logit sum (a large negative constant when fully masked).
  """
  logits = tf.transpose(logits, [0, 2, 1])
  values, indices = tf.math.top_k(logits, top_k)
  # width[b, i, j] = end_position[j] - begin_position[i] over top-k pairs.
  width = (
      tf.expand_dims(indices[:, 1, :], -2) -
      tf.expand_dims(indices[:, 0, :], -1))
  mask = tf.logical_and(width >= 0, width <= max_size)
  # Score each (begin i, end j) pair additively; invalid widths get -1e8.
  scores = (
      tf.expand_dims(values[:, 0, :], -1) +
      tf.expand_dims(values[:, 1, :], -2))
  scores = tf.where(mask, scores, -1e8)
  # Argmax over the flattened k*k pair grid: row index is begin, column end.
  flat_indices = tf.argmax(tf.reshape(scores, (-1, top_k * top_k)), -1)
  begin = tf.gather(
      indices[:, 0, :], tf.math.floordiv(flat_indices, top_k), batch_dims=1)
  end = tf.gather(
      indices[:, 1, :], tf.math.mod(flat_indices, top_k), batch_dims=1)
  # Examples with no valid pair at all fall back to `default`.
  reduced_mask = tf.math.reduce_any(mask, [-1, -2])
  return (tf.where(reduced_mask, begin, default),
          tf.where(reduced_mask, end, default),
          tf.math.reduce_max(scores, [-1, -2]))
@tf.function
def decode_answer(context, begin, end, token_offsets, end_limit):
  """Extracts answer byte-substrings from `context` for token spans.

  Args:
    context: batch of UTF-8 context byte strings.
    begin: per-example begin token index (inclusive).
    end: per-example end token index (inclusive).
    token_offsets: per-example byte offset of each token within `context`.
    end_limit: per-example last valid token index.

  Returns:
    Batch of substrings of `context` covering tokens [begin, end].
  """
  # Start byte of the begin token.
  i = tf.gather(token_offsets, begin, batch_dims=1)
  # Start byte of the token after `end` (clamped), i.e. the exclusive end.
  j = tf.gather(token_offsets, tf.minimum(end + 1, end_limit), batch_dims=1)
  # A span ending at the last token runs to the end of the context string.
  j = tf.where(end == end_limit,
               tf.cast(tf.strings.length(context), tf.int64), j)
  return tf.strings.substr(context, i, j - i)
def distributed_logits_fn(model, x):
  """Runs an inference forward pass of `model` on each replica."""

  def _forward(batch):
    return model(batch, training=False)

  return model.distribute_strategy.run(_forward, args=(x,))
official/nlp/projects/triviaqa/preprocess.py
0 → 100644
View file @
0ab5dcbf
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for preprocessing TriviaQA data."""
import
bisect
import
json
import
operator
import
os
import
re
import
string
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Set
,
Text
,
Tuple
from
absl
import
logging
import
apache_beam
as
beam
from
apache_beam
import
metrics
import
dataclasses
import
nltk
import
numpy
as
np
import
tensorflow.io.gfile
as
gfile
import
sentencepiece
as
spm
from
official.nlp.projects.triviaqa
import
evaluation
from
official.nlp.projects.triviaqa
import
sentencepiece_pb2
@dataclasses.dataclass
class Question(object):
  """A TriviaQA question."""
  # Unique question identifier from the dataset (e.g. 'tc_1648').
  id: Text
  # The question text.
  value: Text


@dataclasses.dataclass
class EvidenceInfo(object):
  """Reference to one evidence document, without its text."""
  # Evidence document filename.
  id: Text
  # Collection the document came from; set from the JSON keys
  # 'EntityPages' or 'SearchResults' (see read_question_answers).
  source: Text
  # Document title.
  title: Text


@dataclasses.dataclass
class Evidence(object):
  """An evidence document together with its full text."""
  info: EvidenceInfo
  text: Text


@dataclasses.dataclass
class Answer(object):
  """Gold answer with its acceptable surface forms."""
  # Canonical answer string.
  value: Text
  # Raw alias strings from the dataset.
  aliases: List[Text]
  # Dataset-provided normalized aliases.
  normalized_aliases: List[Text]


@dataclasses.dataclass
class QuestionAnswer(object):
  """A question with its evidence references and (optional) answer."""
  question: Question
  evidence_info: List[EvidenceInfo]
  answer: Optional[Answer] = None


@dataclasses.dataclass
class QuestionAnswerEvidence(object):
  """A question paired with one fully-read evidence document."""
  question: Question
  evidence: Evidence
  answer: Optional[Answer] = None


@dataclasses.dataclass
class Features(object):
  """Tokenized model features for one evidence window (see MakeFeatures)."""
  # Composite '<question_id>--<evidence_id>' identifier.
  id: Text
  # Index of this sliding window over the evidence document.
  stride_index: int
  question_id: Text
  question: Text
  # UTF-8 encoded evidence text covered by this window.
  context: bytes
  # Sentencepiece ids: question pieces + separator, then context pieces.
  token_ids: List[int]
  # Byte offset into `context` per token; -1 for question tokens.
  token_offsets: List[int]
  # Global (sentence/question-level) token ids.
  global_token_ids: List[int]
  # Per-token index of the owning global token.
  segment_ids: List[int]


@dataclasses.dataclass
class Paragraph(object):
  """Sentence-tokenized paragraph plus its total sentencepiece count."""
  sentences: List[sentencepiece_pb2.SentencePieceText]
  size: int


@dataclasses.dataclass
class AnswerSpan(object):
  """An answer occurrence; byte-based before realignment, token-based after."""
  begin: int  # inclusive
  end: int  # inclusive
  text: Text
def make_paragraph(
    sentence_tokenizer: nltk.tokenize.api.TokenizerI,
    processor: spm.SentencePieceProcessor,
    text: Text,
    paragraph_metric: Optional[metrics.Metrics.DelegatingDistribution] = None,
    sentence_metric: Optional[metrics.Metrics.DelegatingDistribution] = None
) -> Paragraph:
  """Sentence-splits and sentencepiece-tokenizes one paragraph of text.

  Args:
    sentence_tokenizer: NLTK sentence tokenizer.
    processor: sentencepiece processor used to tokenize each sentence.
    text: paragraph text to tokenize.
    paragraph_metric: optional Beam distribution fed the paragraph size.
    sentence_metric: optional Beam distribution fed each sentence size.

  Returns:
    A Paragraph with the tokenized sentences and total piece count.
  """
  tokenized_sentences = []
  total_pieces = 0
  for raw_sentence in sentence_tokenizer.tokenize(text):
    serialized = processor.EncodeAsSerializedProto(raw_sentence)
    piece_text = sentencepiece_pb2.SentencePieceText.FromString(serialized)
    total_pieces += len(piece_text.pieces)
    tokenized_sentences.append(piece_text)
    if sentence_metric:
      sentence_metric.update(len(piece_text.pieces))
  if paragraph_metric:
    paragraph_metric.update(total_pieces)
  return Paragraph(sentences=tokenized_sentences, size=total_pieces)
def read_question_answers(json_path: Text) -> List[QuestionAnswer]:
  """Parses a TriviaQA JSON file into QuestionAnswer records."""
  with gfile.GFile(json_path) as f:
    records = json.load(f)['Data']

  def _parse_one(datum):
    # Build the question; the answer is absent for test-split data.
    question = Question(id=datum['QuestionId'], value=datum['Question'])
    if 'Answer' in datum:
      raw = datum['Answer']
      answer = Answer(
          value=raw['Value'],
          aliases=raw['Aliases'],
          normalized_aliases=raw['NormalizedAliases'])
    else:
      answer = None
    # Evidence comes from two collections; `source` records which one.
    evidence = [
        EvidenceInfo(id=doc['Filename'], title=doc['Title'], source=key)
        for key in ('EntityPages', 'SearchResults')
        for doc in datum.get(key, [])
    ]
    return QuestionAnswer(
        question=question, evidence_info=evidence, answer=answer)

  return [_parse_one(datum) for datum in records]
def alias_answer(answer: Text, include=None):
  """Normalizes an answer alias: lowercase, strip punctuation, squash spaces.

  Args:
    answer: alias to normalize; underscores are treated as spaces.
    include: optional punctuation characters to keep instead of stripping.

  Returns:
    The normalized alias with single spaces and no surrounding whitespace.
  """
  keep = set(include or [])
  dropped = set(string.punctuation + '‘’´`')
  lowered = answer.replace('_', ' ').lower()
  # Punctuation (minus the kept characters) becomes whitespace, then all
  # whitespace runs collapse to single spaces.
  translated = ''.join(
      ' ' if (c in dropped and c not in keep) else c for c in lowered)
  return ' '.join(translated.split()).strip()
def make_answer_set(answer: Answer) -> Set[Text]:
  """Apply less aggressive normalization to the answer aliases."""
  # Each alias is normalized several ways, keeping progressively more
  # punctuation, so matching in evidence text is more forgiving.
  include_variants = ([], [',', '.'], ['-'], [',', '.', '-'],
                      string.punctuation)
  variants = [
      alias_answer(alias, include)
      for alias in [answer.value] + answer.aliases
      for include in include_variants
  ]
  return set(variants + answer.normalized_aliases)
def find_answer_spans(text: bytes,
                      answer_set: Set[Text]) -> List[AnswerSpan]:
  """Finds every case-insensitive occurrence of any answer alias in `text`.

  Returns:
    Byte-offset AnswerSpans sorted by their begin position.
  """
  spans = []
  for candidate in answer_set:
    # Escaped spaces are widened to match either a space or a hyphen.
    pattern = re.escape(candidate).encode('utf-8').replace(b'\\ ', b'[ -]')
    compiled = re.compile(pattern, flags=re.IGNORECASE)
    spans.extend(
        AnswerSpan(
            begin=m.start(),
            end=m.end(),
            text=m.group(0).decode('utf-8')) for m in compiled.finditer(text))
  return sorted(spans, key=operator.attrgetter('begin'))
def realign_answer_span(
    features: Features, answer_set: Optional[Set[Text]],
    processor: spm.SentencePieceProcessor,
    span: AnswerSpan) -> Optional[AnswerSpan]:
  """Align answer span to text with given tokens.

  Converts `span` from byte offsets within `features.context` into an
  inclusive [begin, end] token range over `features.token_offsets`, and
  verifies that the text covered by those tokens still normalizes to a
  known answer.

  Args:
    features: tokenized example whose context contains the span.
    answer_set: normalized acceptable answers; None skips validation.
    processor: sentencepiece processor used to inspect the begin piece.
    span: byte-level answer span to realign.

  Returns:
    Token-level AnswerSpan, or None if validation against `answer_set`
    fails.
  """
  # Locate the token covering span.begin; bisect lands one past it when the
  # span starts mid-token, so step back in that case.
  i = bisect.bisect_left(features.token_offsets, span.begin)
  if i == len(features.token_offsets) or span.begin < features.token_offsets[i]:
    i -= 1
  # Advance j to the last token starting before the span's end byte.
  j = i + 1
  answer_end = span.begin + len(span.text.encode('utf-8'))
  while (j < len(features.token_offsets) and
         features.token_offsets[j] < answer_end):
    j += 1
  j -= 1
  # Bytes actually covered by tokens [i, j].
  sp_answer = (
      features.context[features.token_offsets[i]:features.token_offsets[j + 1]]
      if j + 1 < len(features.token_offsets) else
      features.context[features.token_offsets[i]:])
  # A word-initial piece ('▁') starts at the preceding space; drop it.
  if (processor.IdToPiece(features.token_ids[i]).startswith('▁') and
      features.token_offsets[i] > 0):
    sp_answer = sp_answer[1:]
  sp_answer = evaluation.normalize_answer(sp_answer.decode('utf-8'))
  if answer_set is not None and sp_answer not in answer_set:
    # No need to warn if the cause was breaking word boundaries.
    if len(sp_answer) and not len(sp_answer) > len(
        evaluation.normalize_answer(span.text)):
      logging.warning('%s: "%s" not in %s.', features.question_id, sp_answer,
                      answer_set)
    return None
  return AnswerSpan(begin=i, end=j, text=span.text)
def read_sentencepiece_model(path):
  """Loads a serialized SentencePiece model from `path`."""
  # Note: local renamed from `file` to avoid shadowing the builtin.
  with gfile.GFile(path, 'rb') as model_file:
    serialized_model = model_file.read()
  processor = spm.SentencePieceProcessor()
  processor.LoadFromSerializedProto(serialized_model)
  return processor
class ReadEvidence(beam.DoFn):
  """Function to read evidence.

  Expands each QuestionAnswer into one QuestionAnswerEvidence per evidence
  document, reading the document text from disk.
  """

  def __init__(self, wikipedia_dir: Text, web_dir: Text):
    # Root directories for the two evidence collections.
    self._wikipedia_dir = wikipedia_dir
    self._web_dir = web_dir

  def process(
      self, question_answer: QuestionAnswer
  ) -> Generator[QuestionAnswerEvidence, None, None]:
    """Yields one QuestionAnswerEvidence per evidence document."""
    for info in question_answer.evidence_info:
      if info.source == 'EntityPages':
        evidence_path = os.path.join(self._wikipedia_dir, info.id)
      # BUG FIX: `source` is set from the JSON key 'SearchResults' (see
      # read_question_answers); the previous comparison against
      # 'SearchResult' never matched, so every search-result document
      # incorrectly raised ValueError below.
      elif info.source == 'SearchResults':
        evidence_path = os.path.join(self._web_dir, info.id)
      else:
        raise ValueError(f'Unknown evidence source: {info.source}.')
      with gfile.GFile(evidence_path, 'rb') as f:
        text = f.read().decode('utf-8')
      metrics.Metrics.counter('_', 'documents').inc()
      yield QuestionAnswerEvidence(
          question=question_answer.question,
          evidence=Evidence(info=info, text=text),
          answer=question_answer.answer)
# Special sentencepiece vocabulary symbols used when assembling features
# (see MakeFeatures._make_features).
_CLS_PIECE = '<ans>'  # Leading global CLS/answer token.
_EOS_PIECE = '</s>'  # Appended to the global ids once per sentence.
_SEP_PIECE = '<sep_0>'  # Separates question tokens from context tokens.
# _PARAGRAPH_SEP_PIECE = '<sep_1>'
_NULL_PIECE = '<empty>'  # Trailing "no answer" token appended to the context.
_QUESTION_PIECE = '<unused_34>'  # Global token mirroring each question piece.
class MakeFeatures(beam.DoFn):
  """Function to make features.

  Converts one QuestionAnswerEvidence into a sequence of Features by
  sliding a window of at most `max_num_tokens` sentencepiece tokens (and
  `max_num_global_tokens` global tokens) over the evidence document,
  advancing by whole paragraphs totalling at most `stride` tokens.
  """

  def __init__(self, sentencepiece_model_path: Text, max_num_tokens: int,
               max_num_global_tokens: int, stride: int):
    self._sentencepiece_model_path = sentencepiece_model_path
    self._max_num_tokens = max_num_tokens
    self._max_num_global_tokens = max_num_global_tokens
    self._stride = stride

  def setup(self):
    # Heavy resources are loaded once per worker (beam.DoFn.setup).
    self._sentence_tokenizer = nltk.data.load(
        'tokenizers/punkt/english.pickle')
    self._sentencepiece_processor = read_sentencepiece_model(
        self._sentencepiece_model_path)

  def _make_features(
      self, stride_index: int, paragraph_texts: List[Text],
      paragraphs: List[Paragraph],
      question_answer_evidence: QuestionAnswerEvidence, ids: List[int],
      paragraph_offset: int) -> Tuple[int, Features]:
    """Builds one windowed Features starting at paragraph `paragraph_offset`.

    Args:
      stride_index: index of this window within the document.
      paragraph_texts: all paragraph texts of the evidence document.
      paragraphs: tokenization cache, extended in place as needed.
      question_answer_evidence: source question/evidence pair.
      ids: question token ids (including the trailing separator).
      paragraph_offset: paragraph at which this window starts.

    Returns:
      (next_paragraph_index, features): where the next window should start,
      and the features for this window.
    """
    # Global tokens: one CLS plus one question marker per question piece.
    global_ids = (
        [self._sentencepiece_processor.PieceToId(_CLS_PIECE)] +
        [self._sentencepiece_processor.PieceToId(_QUESTION_PIECE)] * len(ids))
    segment_ids = [i + 1 for i in range(len(ids))]  # offset for CLS token
    token_ids, sentences = [], []
    # Question tokens have no position in the context, hence offset -1.
    offsets, offset, full_text = [-1] * len(ids), 0, True
    for i in range(paragraph_offset, len(paragraph_texts)):
      if i < len(paragraphs):
        paragraph = paragraphs[i]
      else:
        # Tokenize lazily and cache so later strides can reuse the work.
        paragraphs.append(
            make_paragraph(
                self._sentence_tokenizer,
                self._sentencepiece_processor,
                paragraph_texts[i],
                paragraph_metric=metrics.Metrics.distribution(
                    '_', 'paragraphs'),
                sentence_metric=metrics.Metrics.distribution(
                    '_', 'sentences')))
        paragraph = paragraphs[-1]
      for sentence in paragraph.sentences:
        # Stop once this sentence (plus the trailing null token) would
        # overflow the token budget or the global-token budget.
        if (len(ids) + len(token_ids) + len(sentence.pieces) + 1 >=
            self._max_num_tokens or
            len(global_ids) >= self._max_num_global_tokens):
          full_text = False
          break
        for j, piece in enumerate(sentence.pieces):
          token_ids.append(piece.id)
          segment_ids.append(len(global_ids))
          offsets.append(offset + piece.begin)
          # The first piece of a follow-up sentence points at the joining
          # space inserted between sentences.
          if j == 0 and sentences:
            offsets[-1] -= 1
        offset += len(sentence.text.encode('utf-8')) + 1
        global_ids.append(self._sentencepiece_processor.PieceToId(_EOS_PIECE))
        sentences.append(sentence.text)
      if not full_text:
        break
    context = ' '.join(sentences).encode('utf-8')
    # Trailing null token marks the "no answer" position.
    token_ids.append(self._sentencepiece_processor.PieceToId(_NULL_PIECE))
    offsets.append(len(context))
    segment_ids.append(0)
    next_paragraph_index = len(paragraph_texts)
    if not full_text and self._stride > 0:
      # Advance past whole paragraphs whose combined size fits in `stride`.
      shift = paragraphs[paragraph_offset].size
      next_paragraph_index = paragraph_offset + 1
      while (next_paragraph_index < len(paragraphs) and
             shift + paragraphs[next_paragraph_index].size <= self._stride):
        shift += paragraphs[next_paragraph_index].size
        next_paragraph_index += 1
    return next_paragraph_index, Features(
        id='{}--{}'.format(question_answer_evidence.question.id,
                           question_answer_evidence.evidence.info.id),
        stride_index=stride_index,
        question_id=question_answer_evidence.question.id,
        question=question_answer_evidence.question.value,
        context=context,
        token_ids=ids + token_ids,
        global_token_ids=global_ids,
        segment_ids=segment_ids,
        token_offsets=offsets)

  def process(
      self, question_answer_evidence: QuestionAnswerEvidence
  ) -> Generator[Features, None, None]:
    # Tokenize question which is shared among all examples.
    ids = (
        self._sentencepiece_processor.EncodeAsIds(
            question_answer_evidence.question.value) +
        [self._sentencepiece_processor.PieceToId(_SEP_PIECE)])
    # Non-empty, stripped lines of the evidence text become paragraphs.
    paragraph_texts = list(
        filter(lambda p: p,
               map(lambda p: p.strip(),
                   question_answer_evidence.evidence.text.split('\n'))))
    stride_index, paragraphs, paragraph_index = 0, [], 0
    while paragraph_index < len(paragraph_texts):
      paragraph_index, features = self._make_features(
          stride_index, paragraph_texts, paragraphs, question_answer_evidence,
          ids, paragraph_index)
      stride_index += 1
      yield features
def _handle_exceptional_examples(
    features: Features,
    processor: spm.SentencePieceProcessor) -> List[AnswerSpan]:
  """Special cases in data.

  Hand-curated answer spans for known documents where the generic alias
  search fails. Returns an empty list for all other examples.

  Args:
    features: tokenized example to check.
    processor: sentencepiece processor passed through to realignment.

  Returns:
    Token-level answer spans for the known special cases, else [].
  """
  if features.id == 'qw_6687--Viola.txt':
    pattern = 'three strings in common—G, D, and A'.encode('utf-8')
    i = features.context.find(pattern)
    if i != -1:
      # The answer is the trailing 'A' of the quoted phrase.
      span = AnswerSpan(i + len(pattern) - 1, i + len(pattern), 'A')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      return [span]
  if features.id == 'sfq_26183--Vitamin_A.txt':
    pattern = ('Vitamin A is a group of unsaturated nutritional organic '
               'compounds that includes retinol').encode('utf-8')
    i = features.context.find(pattern)
    if i != -1:
      # Accept both 'A' alone and the full 'Vitamin A'.
      span = AnswerSpan(i + pattern.find(b'A'),
                        i + pattern.find(b'A') + 1, 'A')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      spans = [span]
      span = AnswerSpan(i, i + pattern.find(b'A') + 1, 'Vitamin A')
      span = realign_answer_span(features, None, processor, span)
      return spans + [span]
  if features.id == 'odql_292--Colombia.txt':
    pattern = b'Colombia is the third-most populous country in Latin America'
    i = features.context.find(pattern)
    if i != -1:
      span = AnswerSpan(i, i + len(b'Colombia'), 'Colombia')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      return [span]
  if features.id == 'tc_1648--Vietnam.txt':
    pattern = 'Bảo Đại'.encode('utf-8')
    i = features.context.find(pattern)
    if i != -1:
      span = AnswerSpan(i, i + len(pattern), 'Bảo Đại')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      return [span]
  if features.id == 'sfq_22225--Irish_mythology.txt':
    pattern = 'Tír na nÓg'.encode('utf-8')
    spans = []
    i = 0
    while features.context.find(pattern, i) != -1:
      # BUG FIX: search from offset `i`. The body previously called
      # find(pattern) with no start offset, so with two or more occurrences
      # it kept re-finding the first one and never terminated.
      i = features.context.find(pattern, i)
      span = AnswerSpan(i, i + len(pattern), 'Tír na nÓg')
      span = realign_answer_span(features, None, processor, span)
      assert span is not None, 'Span should exist.'
      spans.append(span)
      i += len(pattern)
    return spans
  return []
class FindAnswerSpans(beam.DoFn):
  """Find answer spans in document.

  Consumes (question_id, [Features...]) groups plus a side-input dict of
  question_id -> normalized answer set, and emits each Features with the
  token-level answer spans found in its context.
  """

  def __init__(self, sentencepiece_model_path: Text):
    self._sentencepiece_model_path = sentencepiece_model_path

  def setup(self):
    # Loaded once per worker (beam.DoFn.setup).
    self._sentencepiece_processor = read_sentencepiece_model(
        self._sentencepiece_model_path)

  def process(
      self, element: Tuple[Text, List[Features]],
      answer_sets: Dict[Text, Set[Text]],
  ) -> Generator[Tuple[Features, List[AnswerSpan]], None, None]:
    question_id, features = element
    answer_set = answer_sets[question_id]
    has_answer = False
    for feature in features:
      answer_spans = []
      for answer_span in find_answer_spans(feature.context, answer_set):
        # Keep only spans that still match an answer after realignment to
        # sentencepiece token boundaries.
        realigned_answer_span = realign_answer_span(
            feature, answer_set, self._sentencepiece_processor, answer_span)
        if realigned_answer_span:
          answer_spans.append(realigned_answer_span)
      if not answer_spans:
        # Fall back to hand-curated fixes for known problematic documents.
        answer_spans = _handle_exceptional_examples(
            feature, self._sentencepiece_processor)
      if answer_spans:
        has_answer = True
      else:
        metrics.Metrics.counter('_', 'answerless_examples').inc()
      yield feature, answer_spans
    if not has_answer:
      metrics.Metrics.counter('_', 'answerless_questions').inc()
      logging.error('Question %s has no answer.', question_id)
def make_example(
    features: Features,
    labels: Optional[List[AnswerSpan]] = None
) -> Tuple[Text, Dict[Text, Any]]:
  """Converts Features (and optional answer spans) into a keyed example.

  Args:
    features: tokenized example to serialize.
    labels: optional token-level answer spans to attach as labels.

  Returns:
    ('<id>--<stride_index>', example_dict) pair.
  """
  example = {
      'id': features.id,
      'qid': features.question_id,
      'question': features.question,
      'context': features.context,
      'token_ids': features.token_ids,
      'token_offsets': features.token_offsets,
      'segment_ids': features.segment_ids,
      'global_token_ids': features.global_token_ids,
  }
  if labels:
    # Deduplicate spans; stored as an [n, 2] int64 array of (begin, end).
    unique_spans = {(label.begin, label.end) for label in labels}
    example['answers'] = np.array([list(pair) for pair in unique_spans],
                                  np.int64)
  else:
    example['answers'] = np.zeros([0, 2], np.int64)
  metrics.Metrics.counter('_', 'examples').inc()
  key = f'{features.id}--{features.stride_index}'
  return key, example
def make_pipeline(root: beam.Pipeline,
                  question_answers: List[QuestionAnswer],
                  answer: bool,
                  max_num_tokens: int,
                  max_num_global_tokens: int,
                  stride: int,
                  sentencepiece_model_path: Text,
                  wikipedia_dir: Text,
                  web_dir: Text):
  """Makes a Beam pipeline.

  Args:
    root: pipeline to attach the transforms to.
    question_answers: parsed questions with evidence references.
    answer: when True, locate gold answer spans and attach them as labels;
      when False, emit unlabeled examples.
    max_num_tokens: per-window token budget (see MakeFeatures).
    max_num_global_tokens: per-window global-token budget.
    stride: window advance in tokens.
    sentencepiece_model_path: path to the serialized sentencepiece model.
    wikipedia_dir: directory containing 'EntityPages' evidence files.
    web_dir: directory containing 'SearchResults' evidence files.

  Returns:
    PCollection of (key, example_dict) pairs from `make_example`.
  """
  question_answers = (
      root | 'CreateQuestionAnswers' >> beam.Create(question_answers))
  features = (
      question_answers
      | 'ReadEvidence' >> beam.ParDo(
          ReadEvidence(wikipedia_dir=wikipedia_dir, web_dir=web_dir))
      | 'MakeFeatures' >> beam.ParDo(
          MakeFeatures(
              sentencepiece_model_path=sentencepiece_model_path,
              max_num_tokens=max_num_tokens,
              max_num_global_tokens=max_num_global_tokens,
              stride=stride)))
  if answer:
    # Key features by question id so they can be joined with answer sets.
    features = features | 'KeyFeature' >> beam.Map(
        lambda feature: (feature.question_id, feature))
    # pylint: disable=g-long-lambda
    answer_sets = (
        question_answers
        | 'MakeAnswerSet' >> beam.Map(
            lambda qa: (qa.question.id, make_answer_set(qa.answer))))
    # pylint: enable=g-long-lambda
    examples = (
        features
        | beam.GroupByKey()
        | 'FindAnswerSpans' >> beam.ParDo(
            FindAnswerSpans(sentencepiece_model_path),
            answer_sets=beam.pvalue.AsDict(answer_sets))
        | 'MakeExamplesWithLabels' >> beam.MapTuple(make_example))
  else:
    examples = features | 'MakeExamples' >> beam.Map(make_example)
  return examples
official/nlp/projects/triviaqa/sentencepiece_pb2.py
0 → 100755
View file @
0ab5dcbf
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
# pylint: disable=bad-continuation
# pylint: disable=protected-access
# Generated by the protocol buffer compiler. DO NOT EDIT!
"""Generated protocol buffer code."""
from
google.protobuf
import
descriptor
as
_descriptor
from
google.protobuf
import
message
as
_message
from
google.protobuf
import
reflection
as
_reflection
from
google.protobuf
import
symbol_database
as
_symbol_database
# @@protoc_insertion_point(imports)
_sym_db
=
_symbol_database
.
Default
()
DESCRIPTOR
=
_descriptor
.
FileDescriptor
(
name
=
'third_party/sentencepiece/src/sentencepiece.proto'
,
package
=
'sentencepiece'
,
syntax
=
'proto2'
,
serialized_options
=
None
,
create_key
=
_descriptor
.
_internal_create_key
,
serialized_pb
=
b
'
\n
1third_party/sentencepiece/src/sentencepiece.proto
\x12\r
sentencepiece
\"\xdf\x01\n\x11
SentencePieceText
\x12\x0c\n\x04
text
\x18\x01
\x01
(
\t\x12
>
\n\x06
pieces
\x18\x02
\x03
(
\x0b\x32
..sentencepiece.SentencePieceText.SentencePiece
\x12\r\n\x05
score
\x18\x03
\x01
(
\x02\x1a\x62\n\r
SentencePiece
\x12\r\n\x05
piece
\x18\x01
\x01
(
\t\x12\n\n\x02
id
\x18\x02
\x01
(
\r\x12\x0f\n\x07
surface
\x18\x03
\x01
(
\t\x12\r\n\x05\x62\x65
gin
\x18\x04
\x01
(
\r\x12\x0b\n\x03\x65
nd
\x18\x05
\x01
(
\r
*
\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02
*
\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"
J
\n\x16
NBestSentencePieceText
\x12\x30\n\x06
nbests
\x18\x01
\x03
(
\x0b\x32
.sentencepiece.SentencePieceText'
)
_SENTENCEPIECETEXT_SENTENCEPIECE
=
_descriptor
.
Descriptor
(
name
=
'SentencePiece'
,
full_name
=
'sentencepiece.SentencePieceText.SentencePiece'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'piece'
,
full_name
=
'sentencepiece.SentencePieceText.SentencePiece.piece'
,
index
=
0
,
number
=
1
,
type
=
9
,
cpp_type
=
9
,
label
=
1
,
has_default_value
=
False
,
default_value
=
b
''
.
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
None
,
file
=
DESCRIPTOR
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
FieldDescriptor
(
name
=
'id'
,
full_name
=
'sentencepiece.SentencePieceText.SentencePiece.id'
,
index
=
1
,
number
=
2
,
type
=
13
,
cpp_type
=
3
,
label
=
1
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
None
,
file
=
DESCRIPTOR
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
FieldDescriptor
(
name
=
'surface'
,
full_name
=
'sentencepiece.SentencePieceText.SentencePiece.surface'
,
index
=
2
,
number
=
3
,
type
=
9
,
cpp_type
=
9
,
label
=
1
,
has_default_value
=
False
,
default_value
=
b
''
.
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
None
,
file
=
DESCRIPTOR
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
FieldDescriptor
(
name
=
'begin'
,
full_name
=
'sentencepiece.SentencePieceText.SentencePiece.begin'
,
index
=
3
,
number
=
4
,
type
=
13
,
cpp_type
=
3
,
label
=
1
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
None
,
file
=
DESCRIPTOR
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
FieldDescriptor
(
name
=
'end'
,
full_name
=
'sentencepiece.SentencePieceText.SentencePiece.end'
,
index
=
4
,
number
=
5
,
type
=
13
,
cpp_type
=
3
,
label
=
1
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
None
,
file
=
DESCRIPTOR
,
create_key
=
_descriptor
.
_internal_create_key
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
serialized_options
=
None
,
is_extendable
=
True
,
syntax
=
'proto2'
,
extension_ranges
=
[
(
200
,
536870912
),
],
oneofs
=
[],
serialized_start
=
183
,
serialized_end
=
281
,
)
# Generated by the protocol buffer compiler from sentencepiece.proto.
# DO NOT EDIT by hand: descriptor for the SentencePieceText message
# (fields: text, pieces, score).
_SENTENCEPIECETEXT = _descriptor.Descriptor(
    name='SentencePieceText',
    full_name='sentencepiece.SentencePieceText',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    create_key=_descriptor._internal_create_key,
    fields=[
        _descriptor.FieldDescriptor(
            name='text', full_name='sentencepiece.SentencePieceText.text',
            index=0, number=1, type=9, cpp_type=9, label=1,
            has_default_value=False, default_value=b''.decode('utf-8'),
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR,
            create_key=_descriptor._internal_create_key),
        _descriptor.FieldDescriptor(
            name='pieces', full_name='sentencepiece.SentencePieceText.pieces',
            index=1, number=2, type=11, cpp_type=10, label=3,
            has_default_value=False, default_value=[],
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR,
            create_key=_descriptor._internal_create_key),
        _descriptor.FieldDescriptor(
            name='score', full_name='sentencepiece.SentencePieceText.score',
            index=2, number=3, type=2, cpp_type=6, label=1,
            has_default_value=False, default_value=float(0),
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR,
            create_key=_descriptor._internal_create_key),
    ],
    extensions=[],
    nested_types=[_SENTENCEPIECETEXT_SENTENCEPIECE, ],
    enum_types=[],
    serialized_options=None,
    is_extendable=True,
    syntax='proto2',
    extension_ranges=[(200, 536870912), ],
    oneofs=[],
    serialized_start=69,
    serialized_end=292,
)
# Generated by the protocol buffer compiler from sentencepiece.proto.
# DO NOT EDIT by hand: descriptor for the NBestSentencePieceText message
# (a repeated field of SentencePieceText entries).
_NBESTSENTENCEPIECETEXT = _descriptor.Descriptor(
    name='NBestSentencePieceText',
    full_name='sentencepiece.NBestSentencePieceText',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    create_key=_descriptor._internal_create_key,
    fields=[
        _descriptor.FieldDescriptor(
            name='nbests',
            full_name='sentencepiece.NBestSentencePieceText.nbests',
            index=0, number=1, type=11, cpp_type=10, label=3,
            has_default_value=False, default_value=[],
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR,
            create_key=_descriptor._internal_create_key),
    ],
    extensions=[],
    nested_types=[],
    enum_types=[],
    serialized_options=None,
    is_extendable=False,
    syntax='proto2',
    extension_ranges=[],
    oneofs=[],
    serialized_start=294,
    serialized_end=368,
)
# Generated by the protocol buffer compiler. DO NOT EDIT by hand.
# Wire the nested/repeated message types into their parent descriptors,
# then register everything with the default symbol database.
_SENTENCEPIECETEXT_SENTENCEPIECE.containing_type = _SENTENCEPIECETEXT
_SENTENCEPIECETEXT.fields_by_name[
    'pieces'].message_type = _SENTENCEPIECETEXT_SENTENCEPIECE
_NBESTSENTENCEPIECETEXT.fields_by_name[
    'nbests'].message_type = _SENTENCEPIECETEXT
DESCRIPTOR.message_types_by_name['SentencePieceText'] = _SENTENCEPIECETEXT
DESCRIPTOR.message_types_by_name[
    'NBestSentencePieceText'] = _NBESTSENTENCEPIECETEXT
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

# Build the concrete Python message classes from the descriptors above.
# SentencePiece is exposed as a nested class of SentencePieceText.
SentencePieceText = _reflection.GeneratedProtocolMessageType(
    'SentencePieceText',
    (_message.Message,),
    {
        'SentencePiece': _reflection.GeneratedProtocolMessageType(
            'SentencePiece',
            (_message.Message,),
            {
                'DESCRIPTOR': _SENTENCEPIECETEXT_SENTENCEPIECE,
                '__module__': 'official.nlp.projects.triviaqa.sentencepiece_pb2'
                # @@protoc_insertion_point(class_scope:sentencepiece.SentencePieceText.SentencePiece)
            }),
        'DESCRIPTOR': _SENTENCEPIECETEXT,
        '__module__': 'official.nlp.projects.triviaqa.sentencepiece_pb2'
        # @@protoc_insertion_point(class_scope:sentencepiece.SentencePieceText)
    })
_sym_db.RegisterMessage(SentencePieceText)
_sym_db.RegisterMessage(SentencePieceText.SentencePiece)

NBestSentencePieceText = _reflection.GeneratedProtocolMessageType(
    'NBestSentencePieceText',
    (_message.Message,),
    {
        'DESCRIPTOR': _NBESTSENTENCEPIECETEXT,
        '__module__': 'official.nlp.projects.triviaqa.sentencepiece_pb2'
        # @@protoc_insertion_point(class_scope:sentencepiece.NBestSentencePieceText)
    })
_sym_db.RegisterMessage(NBestSentencePieceText)
# @@protoc_insertion_point(module_scope)
official/nlp/projects/triviaqa/train.py
0 → 100644
View file @
0ab5dcbf
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TriviaQA training script."""
import
collections
import
contextlib
import
functools
import
json
import
operator
import
os
from
absl
import
app
from
absl
import
flags
from
absl
import
logging
import
gin
import
tensorflow
as
tf
import
tensorflow_datasets
as
tfds
import
sentencepiece
as
spm
from
official.nlp
import
optimization
as
nlp_optimization
from
official.nlp.configs
import
encoders
from
official.nlp.projects.triviaqa
import
evaluation
from
official.nlp.projects.triviaqa
import
inputs
from
official.nlp.projects.triviaqa
import
modeling
from
official.nlp.projects.triviaqa
import
prediction
# Command-line flags: data/model paths, encoder selection, sequence-length
# and training hyperparameters consumed by main() below.
flags.DEFINE_string('data_dir', None, 'Data directory for TensorFlow Datasets.')

flags.DEFINE_string(
    'validation_gold_path', None,
    'Path to golden validation. Usually, the wikipedia-dev.json file.')

flags.DEFINE_string('model_dir', None,
                    'Directory for checkpoints and summaries.')

# NOTE: fixed help-text typo ('coniguration' -> 'configuration').
flags.DEFINE_string('model_config_path', None,
                    'JSON file containing model configuration.')

flags.DEFINE_string('sentencepiece_model_path', None,
                    'Path to sentence piece model.')

flags.DEFINE_enum('encoder', 'bigbird',
                  ['bert', 'bigbird', 'albert', 'mobilebert'],
                  'Which transformer encoder model to use.')

flags.DEFINE_integer('bigbird_block_size', 64,
                     'Size of blocks for sparse block attention.')

flags.DEFINE_string('init_checkpoint_path', None,
                    'Path from which to initialize weights.')

flags.DEFINE_integer('train_sequence_length', 4096,
                     'Maximum number of tokens for training.')

flags.DEFINE_integer('train_global_sequence_length', 320,
                     'Maximum number of global tokens for training.')

flags.DEFINE_integer('validation_sequence_length', 4096,
                     'Maximum number of tokens for validation.')

flags.DEFINE_integer('validation_global_sequence_length', 320,
                     'Maximum number of global tokens for validation.')

flags.DEFINE_integer('batch_size', 32, 'Size of batch.')

flags.DEFINE_string('master', '', 'Address of the TPU master.')

flags.DEFINE_integer('decode_top_k', 8,
                     'Maximum number of tokens to consider for begin/end.')

flags.DEFINE_integer('decode_max_size', 16,
                     'Maximum number of sentence pieces in an answer.')

flags.DEFINE_float('dropout_rate', 0.1, 'Dropout rate for hidden layers.')

flags.DEFINE_float('attention_dropout_rate', 0.3,
                   'Dropout rate for attention layers.')

flags.DEFINE_float('label_smoothing', 1e-1, 'Degree of label smoothing.')

flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files')

FLAGS = flags.FLAGS
@contextlib.contextmanager
def worker_context():
  """Context manager placing ops on the remote worker when a TPU master is set.

  When --master is empty this is a no-op context; otherwise all ops created
  inside the block are pinned to the '/job:worker' device.
  """
  if not FLAGS.master:
    yield
    return
  with tf.device('/job:worker') as device:
    yield device
def read_sentencepiece_model(path):
  """Loads a serialized SentencePiece model from `path` (gfile-compatible)."""
  with tf.io.gfile.GFile(path, 'rb') as model_file:
    serialized_model = model_file.read()
  processor = spm.SentencePieceProcessor()
  processor.LoadFromSerializedProto(serialized_model)
  return processor
# Rename old BERT v1 configuration parameters.
# Maps legacy JSON config keys (BERT v1 / BigBird style) to the field names
# expected by encoders.EncoderConfig; applied in read_model_config().
_MODEL_CONFIG_REPLACEMENTS = {
    'num_hidden_layers': 'num_layers',
    'attention_probs_dropout_prob': 'attention_dropout_rate',
    'hidden_dropout_prob': 'dropout_rate',
    'hidden_act': 'hidden_activation',
    'window_size': 'block_size',
}
def read_model_config(encoder,
                      path,
                      bigbird_block_size=None) -> encoders.EncoderConfig:
  """Merges the JSON configuration into the encoder configuration.

  Args:
    encoder: Encoder type name (e.g. 'bigbird') for encoders.EncoderConfig.
    path: Path to a JSON model configuration file (gfile-compatible).
    bigbird_block_size: Value written into the config's 'block_size' field.

  Returns:
    An encoders.EncoderConfig with JSON values overriding the defaults.
  """
  with tf.io.gfile.GFile(path) as config_file:
    raw_config = json.load(config_file)
  # Translate legacy BERT v1 key names to their modern equivalents.
  for old_name, new_name in _MODEL_CONFIG_REPLACEMENTS.items():
    if old_name in raw_config:
      raw_config[new_name] = raw_config.pop(old_name)
  # Command-line flags take precedence over values from the JSON file.
  raw_config['attention_dropout_rate'] = FLAGS.attention_dropout_rate
  raw_config['dropout_rate'] = FLAGS.dropout_rate
  raw_config['block_size'] = bigbird_block_size
  encoder_config = encoders.EncoderConfig(type=encoder)
  # Apply only the keys the encoder schema understands; warn on the rest.
  known_keys = encoder_config.get().as_dict().keys()
  overrides = {}
  for key, value in raw_config.items():
    if key not in known_keys:
      logging.warning('Ignoring config parameter %s=%s', key, value)
    else:
      overrides[key] = value
  encoder_config.get().override(overrides)
  return encoder_config
@gin.configurable(blacklist=[
    'model',
    'strategy',
    'train_dataset',
    'model_dir',
    'init_checkpoint_path',
    'evaluate_fn',
])
def fit(model,
        strategy,
        train_dataset,
        model_dir,
        init_checkpoint_path=None,
        evaluate_fn=None,
        learning_rate=1e-5,
        learning_rate_polynomial_decay_rate=1.,
        weight_decay_rate=1e-1,
        num_warmup_steps=5000,
        num_decay_steps=51000,
        num_epochs=6):
  """Train and evaluate.

  Trains `model` for up to `num_epochs` epochs, checkpointing once per epoch,
  optionally evaluating after each epoch and exporting a SavedModel whenever
  the 'exact_match' metric improves. Optimization hyperparameters are
  gin-configurable.

  Args:
    model: Keras model with an `encoder` attribute (used for warm-start).
    strategy: tf.distribute strategy under which optimizer/model are built.
    train_dataset: Dataset of (features, labels) batches for model.fit.
    model_dir: Directory for checkpoints, summaries and SavedModel export.
    init_checkpoint_path: Optional checkpoint to warm-start the encoder from.
    evaluate_fn: Optional zero-arg callable returning a metrics dict; must
      contain an 'exact_match' entry.
    learning_rate: Peak learning rate after warmup.
    learning_rate_polynomial_decay_rate: Power of the polynomial decay.
    weight_decay_rate: AdamW weight decay rate.
    num_warmup_steps: Linear warmup steps.
    num_decay_steps: Steps over which the learning rate decays to 0.
    num_epochs: Total number of training epochs.
  """
  hparams = dict(
      learning_rate=learning_rate,
      num_decay_steps=num_decay_steps,
      num_warmup_steps=num_warmup_steps,
      num_epochs=num_epochs,
      weight_decay_rate=weight_decay_rate,
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      label_smoothing=FLAGS.label_smoothing)
  logging.info(hparams)
  # Linear warmup into a polynomial decay down to 0.
  learning_rate_schedule = nlp_optimization.WarmUp(
      learning_rate,
      tf.keras.optimizers.schedules.PolynomialDecay(
          learning_rate,
          num_decay_steps,
          end_learning_rate=0.,
          power=learning_rate_polynomial_decay_rate),
      num_warmup_steps)
  with strategy.scope():
    optimizer = nlp_optimization.AdamWeightDecay(
        learning_rate_schedule,
        weight_decay_rate=weight_decay_rate,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
    model.compile(optimizer, loss=modeling.SpanOrCrossEntropyLoss())

  def init_fn(init_checkpoint_path):
    # Warm-start only the encoder weights from the given checkpoint.
    ckpt = tf.train.Checkpoint(encoder=model.encoder)
    ckpt.restore(init_checkpoint_path).assert_existing_objects_matched()

  with worker_context():
    # max_to_keep=None retains every per-epoch checkpoint; the epoch loop
    # below uses len(checkpoints) to resume from the right epoch.
    ckpt_manager = tf.train.CheckpointManager(
        tf.train.Checkpoint(model=model, optimizer=optimizer),
        model_dir,
        max_to_keep=None,
        init_fn=(functools.partial(init_fn, init_checkpoint_path)
                 if init_checkpoint_path else None))
    with strategy.scope():
      ckpt_manager.restore_or_initialize()
    val_summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'val'))
    best_exact_match = 0.
    # One checkpoint is saved per epoch, so the number of existing
    # checkpoints equals the number of completed epochs.
    for epoch in range(len(ckpt_manager.checkpoints), num_epochs):
      model.fit(
          train_dataset,
          callbacks=[
              tf.keras.callbacks.TensorBoard(model_dir, write_graph=False),
          ])
      ckpt_path = ckpt_manager.save()
      if evaluate_fn is None:
        continue
      metrics = evaluate_fn()
      logging.info('Epoch %d: %s', epoch + 1, metrics)
      # Export a SavedModel only when validation exact match improves.
      if best_exact_match < metrics['exact_match']:
        best_exact_match = metrics['exact_match']
        model.save(os.path.join(model_dir, 'export'), include_optimizer=False)
        logging.info('Exporting %s as SavedModel.', ckpt_path)
      with val_summary_writer.as_default():
        for name, data in metrics.items():
          tf.summary.scalar(name, data, epoch + 1)
def evaluate(sp_processor, features_map_fn, labels_map_fn, logits_fn,
             decode_logits_fn, split_and_pad_fn, distribute_strategy,
             validation_dataset, ground_truth):
  """Run evaluation.

  Computes the mean validation loss, decodes the highest-scoring answer span
  per question, and scores the predictions against `ground_truth` with the
  TriviaQA evaluation script.

  Args:
    sp_processor: SentencePiece processor (used to inspect piece strings).
    features_map_fn: Maps raw features to model inputs.
    labels_map_fn: Maps (token_ids, labels) to model targets.
    logits_fn: Distributed forward-pass function returning per-replica logits.
    decode_logits_fn: Decodes logits into (begin, end, scores) span triples.
    split_and_pad_fn: Splits/pads a batch across replicas.
    distribute_strategy: tf.distribute strategy used to gather logits.
    validation_dataset: Dataset of (features, labels) validation batches.
    ground_truth: Dict of question id -> gold answer.

  Returns:
    Metrics dict from evaluation.evaluate_triviaqa plus a 'loss' entry.
  """
  loss_metric = tf.keras.metrics.Mean()

  @tf.function
  def update_loss(y, logits):
    loss_fn = modeling.SpanOrCrossEntropyLoss(
        reduction=tf.keras.losses.Reduction.NONE)
    return loss_metric(loss_fn(y, logits))

  # qid -> list of (score, answer_text) candidates across batches.
  predictions = collections.defaultdict(list)
  for _, (features, labels) in validation_dataset.enumerate():
    token_ids = features['token_ids']
    y = labels_map_fn(token_ids, labels)
    x = split_and_pad_fn(features_map_fn(features))
    # Gather per-replica logits and drop padding rows added by split_and_pad.
    logits = tf.concat(
        distribute_strategy.experimental_local_results(logits_fn(x)), 0)
    logits = logits[:features['token_ids'].shape[0]]
    update_loss(y, logits)
    end_limit = token_ids.row_lengths() - 1  # inclusive
    begin, end, scores = decode_logits_fn(logits, end_limit)
    answers = prediction.decode_answer(features['context'], begin, end,
                                       features['token_offsets'],
                                       end_limit).numpy()
    for _, (qid, token_id, offset, score, answer) in enumerate(
        zip(features['qid'].numpy(),
            tf.gather(features['token_ids'], begin, batch_dims=1).numpy(),
            tf.gather(features['token_offsets'], begin, batch_dims=1).numpy(),
            scores, answers)):
      if not answer:
        continue
      # If the begin piece carries the SentencePiece word-boundary marker
      # and isn't at the context start, drop the leading space character.
      if sp_processor.IdToPiece(int(token_id)).startswith('▁') and offset > 0:
        answer = answer[1:]
      predictions[qid.decode('utf-8')].append((score, answer.decode('utf-8')))
  # Keep only the highest-scoring candidate per question, normalized.
  predictions = {
      qid: evaluation.normalize_answer(
          sorted(answers, key=operator.itemgetter(0), reverse=True)[0][1])
      for qid, answers in predictions.items()
  }
  metrics = evaluation.evaluate_triviaqa(ground_truth, predictions, mute=True)
  metrics['loss'] = loss_metric.result().numpy()
  return metrics
def main(argv):
  """Entry point: builds inputs, model and strategy, then runs fit()."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  gin.parse_config(FLAGS.gin_bindings)
  model_config = read_model_config(
      FLAGS.encoder,
      FLAGS.model_config_path,
      bigbird_block_size=FLAGS.bigbird_block_size)
  logging.info(model_config.get().as_dict())
  # Configure input processing.
  sp_processor = read_sentencepiece_model(FLAGS.sentencepiece_model_path)
  features_map_fn = functools.partial(
      inputs.features_map_fn,
      local_radius=FLAGS.bigbird_block_size,
      relative_pos_max_distance=24,
      use_hard_g2l_mask=True,
      padding_id=sp_processor.PieceToId('<pad>'),
      eos_id=sp_processor.PieceToId('</s>'),
      null_id=sp_processor.PieceToId('<empty>'),
      cls_id=sp_processor.PieceToId('<ans>'),
      sep_id=sp_processor.PieceToId('<sep_0>'))
  train_features_map_fn = tf.function(
      functools.partial(
          features_map_fn,
          sequence_length=FLAGS.train_sequence_length,
          global_sequence_length=FLAGS.train_global_sequence_length),
      autograph=False)
  train_labels_map_fn = tf.function(
      functools.partial(
          inputs.labels_map_fn, sequence_length=FLAGS.train_sequence_length))
  # Connect to TPU cluster.
  if FLAGS.master:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(FLAGS.master)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
  else:
    strategy = tf.distribute.MirroredStrategy()
  # Initialize datasets.
  with worker_context():
    # Touch the global RNG inside the worker context — presumably so its
    # state is created on the worker device; TODO confirm.
    _ = tf.random.get_global_generator()
    train_dataset = inputs.read_batches(
        FLAGS.data_dir,
        tfds.Split.TRAIN,
        FLAGS.batch_size,
        shuffle=True,
        drop_final_batch=True)
    validation_dataset = inputs.read_batches(FLAGS.data_dir,
                                             tfds.Split.VALIDATION,
                                             FLAGS.batch_size)

    def train_map_fn(x, y):
      # Build model features and label-smoothed targets for one batch.
      features = train_features_map_fn(x)
      labels = modeling.smooth_labels(FLAGS.label_smoothing,
                                      train_labels_map_fn(x['token_ids'], y),
                                      features['question_lengths'],
                                      features['token_ids'])
      return features, labels

    train_dataset = train_dataset.map(train_map_fn, 16).prefetch(16)
  # Initialize model and compile.
  with strategy.scope():
    model = modeling.TriviaQaModel(model_config, FLAGS.train_sequence_length)
  logits_fn = tf.function(
      functools.partial(prediction.distributed_logits_fn, model))
  decode_logits_fn = tf.function(
      functools.partial(prediction.decode_logits, FLAGS.decode_top_k,
                        FLAGS.decode_max_size))
  split_and_pad_fn = tf.function(
      functools.partial(prediction.split_and_pad, strategy, FLAGS.batch_size))
  # Evaluation strategy.
  with tf.io.gfile.GFile(FLAGS.validation_gold_path) as f:
    ground_truth = {
        datum['QuestionId']: datum['Answer'] for datum in json.load(f)['Data']
    }
  validation_features_map_fn = tf.function(
      functools.partial(
          features_map_fn,
          sequence_length=FLAGS.validation_sequence_length,
          global_sequence_length=FLAGS.validation_global_sequence_length),
      autograph=False)
  validation_labels_map_fn = tf.function(
      functools.partial(
          inputs.labels_map_fn,
          sequence_length=FLAGS.validation_sequence_length))
  evaluate_fn = functools.partial(
      evaluate,
      sp_processor=sp_processor,
      features_map_fn=validation_features_map_fn,
      labels_map_fn=validation_labels_map_fn,
      logits_fn=logits_fn,
      decode_logits_fn=decode_logits_fn,
      split_and_pad_fn=split_and_pad_fn,
      distribute_strategy=strategy,
      validation_dataset=validation_dataset,
      ground_truth=ground_truth)
  logging.info('Model initialized. Beginning training fit loop.')
  fit(model, strategy, train_dataset, FLAGS.model_dir,
      FLAGS.init_checkpoint_path, evaluate_fn)
if __name__ == '__main__':
  # These flags have no usable defaults; fail fast before running main().
  flags.mark_flags_as_required([
      'model_config_path', 'model_dir', 'sentencepiece_model_path',
      'validation_gold_path'
  ])
  app.run(main)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment