ModelZoo / ResNet50_tensorflow

Commit db39ef82, authored Jun 28, 2020 by A. Unique TensorFlower
Parent: 997eaa19

    Internal change

    PiperOrigin-RevId: 318755856

Showing 3 changed files, with 202 additions and 31 deletions:

official/nlp/configs/bert.py (+7, -0)
official/nlp/tasks/question_answering.py (+136, -6)
official/nlp/tasks/question_answering_test.py (+59, -25)
official/nlp/configs/bert.py

@@ -126,10 +126,17 @@ class QADataConfig(cfg.DataConfig):

```python
class QADevDataConfig(cfg.DataConfig):
  """Dev Data config for question answering (tasks/question_answering)."""
  input_path: str = ""
  input_preprocessed_data_path: str = ""
  version_2_with_negative: bool = False
  doc_stride: int = 128
  global_batch_size: int = 48
  is_training: bool = False
  seq_length: int = 384
  query_length: int = 64
  drop_remainder: bool = False
  vocab_file: str = ""
  tokenization: str = "WordPiece"  # WordPiece or SentencePiece
  do_lower_case: bool = True


@dataclasses.dataclass
...
```
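For reference, a minimal sketch of how this dev-data config might be filled in when assembling a task config. The field names come from the diff above; the paths are illustrative placeholders, not values taken from this commit:

```python
from official.nlp.configs import bert

# Illustrative values only; the paths below are hypothetical placeholders.
validation_data = bert.QADevDataConfig(
    input_path="/path/to/dev-v1.1.json",             # raw SQuAD-style dev JSON
    input_preprocessed_data_path="/tmp/squad_eval",  # where eval.tf_record lands
    vocab_file="/path/to/vocab.txt",
    tokenization="WordPiece",  # or "SentencePiece"
    do_lower_case=True)
```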
official/nlp/tasks/question_answering.py

@@ -14,7 +14,10 @@

```python
# limitations under the License.
# ==============================================================================
"""Question answering task."""
import collections
import json
import os

from absl import logging
import dataclasses
import tensorflow as tf
import tensorflow_hub as hub
```
@@ -22,7 +25,12 @@ import tensorflow_hub as hub

```python
from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.bert import input_pipeline
from official.nlp.bert import squad_evaluate_v1_1
from official.nlp.bert import squad_evaluate_v2_0
from official.nlp.bert import tokenization
from official.nlp.configs import encoders
from official.nlp.data import squad_lib as squad_lib_wp
from official.nlp.data import squad_lib_sp
from official.nlp.modeling import models
from official.nlp.tasks import utils
```
@@ -33,6 +41,9 @@ class QuestionAnsweringConfig(cfg.TaskConfig):

```python
  # At most one of `init_checkpoint` and `hub_module_url` can be specified.
  init_checkpoint: str = ''
  hub_module_url: str = ''
  n_best_size: int = 20
  max_answer_length: int = 30
  null_score_diff_threshold: float = 0.0
  model: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())
  train_data: cfg.DataConfig = cfg.DataConfig()
```
@@ -41,10 +52,7 @@ class QuestionAnsweringConfig(cfg.TaskConfig):

```python
@base_task.register_task_cls(QuestionAnsweringConfig)
class QuestionAnsweringTask(base_task.Task):
  """Task object for question answering."""

  def __init__(self, params=cfg.TaskConfig):
    super(QuestionAnsweringTask, self).__init__(params)
```
@@ -56,6 +64,14 @@ class QuestionAnsweringTask(base_task.Task):

```python
    else:
      self._hub_module = None

    if params.validation_data.tokenization == 'WordPiece':
      self.squad_lib = squad_lib_wp
    elif params.validation_data.tokenization == 'SentencePiece':
      self.squad_lib = squad_lib_sp
    else:
      raise ValueError('Unsupported tokenization method: {}'.format(
          params.validation_data.tokenization))

  def build_model(self):
    if self._hub_module:
      encoder_network = utils.get_encoder_from_hub(self._hub_module)
```
@@ -85,9 +101,53 @@ class QuestionAnsweringTask(base_task.Task):

```python
    loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
    return loss

  def _preprocess_eval_data(self, params):
    eval_examples = self.squad_lib.read_squad_examples(
        input_file=params.input_path,
        is_training=False,
        version_2_with_negative=params.version_2_with_negative)

    temp_file_path = params.input_preprocessed_data_path or '/tmp'
    eval_writer = self.squad_lib.FeatureWriter(
        filename=os.path.join(temp_file_path, 'eval.tf_record'),
        is_training=False)
    eval_features = []

    def _append_feature(feature, is_padding):
      if not is_padding:
        eval_features.append(feature)
      eval_writer.process_feature(feature)

    kwargs = dict(
        examples=eval_examples,
        tokenizer=tokenization.FullTokenizer(
            vocab_file=params.vocab_file, do_lower_case=params.do_lower_case),
        max_seq_length=params.seq_length,
        doc_stride=params.doc_stride,
        max_query_length=params.query_length,
        is_training=False,
        output_fn=_append_feature,
        batch_size=params.global_batch_size)
    if params.tokenization == 'SentencePiece':
      # squad_lib_sp requires one more argument 'do_lower_case'.
      kwargs['do_lower_case'] = params.do_lower_case

    eval_dataset_size = self.squad_lib.convert_examples_to_features(**kwargs)
    eval_writer.close()

    logging.info('***** Evaluation input stats *****')
    logging.info('  Num orig examples = %d', len(eval_examples))
    logging.info('  Num split examples = %d', len(eval_features))
    logging.info('  Batch size = %d', params.global_batch_size)
    logging.info('  Dataset size = %d', eval_dataset_size)

    return eval_writer.filename, eval_examples, eval_features

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for the question answering task."""
    if params.input_path == 'dummy':
      # Dummy training data for unit test.
      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        x = dict(
            ...
```
@@ -105,11 +165,17 @@ class QuestionAnsweringTask(base_task.Task):

```python
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    if params.is_training:
      input_path = params.input_path
    else:
      input_path, self._eval_examples, self._eval_features = (
          self._preprocess_eval_data(params))

    batch_size = input_context.get_per_replica_batch_size(
        params.global_batch_size) if input_context else params.global_batch_size
    # TODO(chendouble): add and use nlp.data.question_answering_dataloader.
    dataset = input_pipeline.create_squad_dataset(
        input_path,
        params.seq_length,
        batch_size,
        is_training=params.is_training,
        ...
```
@@ -141,6 +207,70 @@ class QuestionAnsweringTask(base_task.Task):

```python
        y_true=labels,
        # labels has keys 'start_positions' and 'end_positions'.
        y_pred={'start_positions': start_logits,
                'end_positions': end_logits})

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    features, _ = inputs
    unique_ids = features.pop('unique_ids')
    model_outputs = self.inference_step(features, model)
    start_logits, end_logits = model_outputs
    logs = {
        self.loss: 0.0,  # TODO(lehou): compute the real validation loss.
        'unique_ids': unique_ids,
        'start_logits': start_logits,
        'end_logits': end_logits,
    }
    return logs

  raw_aggregated_result = collections.namedtuple(
      'RawResult', ['unique_id', 'start_logits', 'end_logits'])

  def aggregate_logs(self, state=None, step_outputs=None):
    assert step_outputs is not None, 'Got no logs from self.validation_step.'
    if state is None:
      state = []

    for unique_ids, start_logits, end_logits in zip(
        step_outputs['unique_ids'],
        step_outputs['start_logits'],
        step_outputs['end_logits']):
      u_ids, s_logits, e_logits = (
          unique_ids.numpy(), start_logits.numpy(), end_logits.numpy())
      if u_ids.size == 1:
        u_ids = [u_ids]
        s_logits = [s_logits]
        e_logits = [e_logits]
      for values in zip(u_ids, s_logits, e_logits):
        state.append(self.raw_aggregated_result(
            unique_id=values[0],
            start_logits=values[1].tolist(),
            end_logits=values[2].tolist()))
    return state

  def reduce_aggregated_logs(self, aggregated_logs):
    all_predictions, _, scores_diff = (
        self.squad_lib.postprocess_output(
            self._eval_examples,
            self._eval_features,
            aggregated_logs,
            self.task_config.n_best_size,
            self.task_config.max_answer_length,
            self.task_config.validation_data.do_lower_case,
            version_2_with_negative=(
                self.task_config.validation_data.version_2_with_negative),
            null_score_diff_threshold=(
                self.task_config.null_score_diff_threshold),
            verbose=False))

    with tf.io.gfile.GFile(
        self.task_config.validation_data.input_path, 'r') as reader:
      dataset_json = json.load(reader)
      pred_dataset = dataset_json['data']
    if self.task_config.validation_data.version_2_with_negative:
      eval_metrics = squad_evaluate_v2_0.evaluate(
          pred_dataset, all_predictions, scores_diff)
    else:
      eval_metrics = squad_evaluate_v1_1.evaluate(
          pred_dataset, all_predictions)
    return eval_metrics

  def initialize(self, model):
    """Load a pretrained checkpoint (if exists) and then train from iter 0."""
    ckpt_dir_or_file = self.task_config.init_checkpoint
    ...
```
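Taken together, the new methods wire SQuAD-style evaluation through the task: `build_inputs` preprocesses the dev set and caches `_eval_examples`/`_eval_features`, `validation_step` emits per-batch logits keyed by `unique_ids`, `aggregate_logs` flattens them into `RawResult` tuples, and `reduce_aggregated_logs` post-processes and scores them. A minimal sketch of that evaluation loop, assuming `task`, `model`, and `val_dataset` were built the way the test below builds them:

```python
# Sketch only: assumes task = QuestionAnsweringTask(config),
# model = task.build_model(), and
# val_dataset = task.build_inputs(config.validation_data).
state = None
for batch in val_dataset:
  logs = task.validation_step(batch, model)       # unique_ids + logits
  state = task.aggregate_logs(state=state, step_outputs=logs)
metrics = task.reduce_aggregated_logs(state)      # runs squad_evaluate_*
print(metrics)  # for SQuAD v1.1 this includes 'final_f1'
```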
official/nlp/tasks/question_answering_test.py

@@ -14,8 +14,10 @@

```python
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.question_answering."""
import functools
import itertools
import json
import os

from absl.testing import parameterized
import tensorflow as tf

from official.nlp.bert import configs
```
@@ -25,30 +27,67 @@ from official.nlp.configs import encoders

```python
from official.nlp.tasks import question_answering


class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(QuestionAnsweringTaskTest, self).setUp()
    self._encoder_config = encoders.TransformerEncoderConfig(
        vocab_size=30522, num_layers=1)
    self._train_data_config = bert.QADataConfig(
        input_path="dummy", seq_length=128, global_batch_size=1)

    val_data = {"version": "1.1",
                "data": [{"paragraphs": [
                    {"context": "Sky is blue.",
                     "qas": [{"question": "What is blue?",
                              "id": "1234",
                              "answers": [{"text": "Sky", "answer_start": 0},
                                          {"text": "Sky", "answer_start": 0},
                                          {"text": "Sky", "answer_start": 0}]
                             }]}]}]}
    self._val_input_path = os.path.join(self.get_temp_dir(), "val_data.json")
    with tf.io.gfile.GFile(self._val_input_path, "w") as writer:
      writer.write(json.dumps(val_data, indent=4) + "\n")

    self._test_vocab = os.path.join(self.get_temp_dir(), "vocab.txt")
    with tf.io.gfile.GFile(self._test_vocab, "w") as writer:
      writer.write("[PAD]\n[UNK]\n[CLS]\n[SEP]\n[MASK]\nsky\nis\nblue\n")

  def _get_validation_data_config(self, version_2_with_negative=False):
    return bert.QADevDataConfig(
        input_path=self._val_input_path,
        input_preprocessed_data_path=self.get_temp_dir(),
        seq_length=128,
        global_batch_size=1,
        version_2_with_negative=version_2_with_negative,
        vocab_file=self._test_vocab,
        tokenization="WordPiece",
        do_lower_case=True)

  def _run_task(self, config):
    task = question_answering.QuestionAnsweringTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    task.initialize(model)

    train_dataset = task.build_inputs(config.train_data)
    train_iterator = iter(train_dataset)
    optimizer = tf.keras.optimizers.SGD(lr=0.1)
    task.train_step(next(train_iterator), model, optimizer, metrics=metrics)

    val_dataset = task.build_inputs(config.validation_data)
    val_iterator = iter(val_dataset)
    logs = task.validation_step(next(val_iterator), model, metrics=metrics)
    logs = task.aggregate_logs(step_outputs=logs)
    metrics = task.reduce_aggregated_logs(logs)
    self.assertIn("final_f1", metrics)

  @parameterized.parameters(itertools.product(
      (False, True),
      ("WordPiece", "SentencePiece"),
  ))
  def test_task(self, version_2_with_negative, tokenization):
    # Saves a checkpoint.
    pretrain_cfg = bert.BertPretrainerConfig(
        encoder=self._encoder_config,
        ...
```
@@ -65,22 +104,16 @@ class QuestionAnsweringTaskTest(tf.test.TestCase):

```python
    config = question_answering.QuestionAnsweringConfig(
        init_checkpoint=saved_path,
        model=self._encoder_config,
        train_data=self._train_data_config,
        validation_data=self._get_validation_data_config(
            version_2_with_negative))
    self._run_task(config)

  def test_task_with_fit(self):
    config = question_answering.QuestionAnsweringConfig(
        model=self._encoder_config,
        train_data=self._train_data_config,
        validation_data=self._get_validation_data_config())
    task = question_answering.QuestionAnsweringTask(config)
    model = task.build_model()
    model = task.compile_model(
        ...
```

@@ -122,7 +155,8 @@ class QuestionAnsweringTaskTest(tf.test.TestCase):

```python
    config = question_answering.QuestionAnsweringConfig(
        hub_module_url=hub_module_url,
        model=self._encoder_config,
        train_data=self._train_data_config,
        validation_data=self._get_validation_data_config())
    self._run_task(config)
```
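As a side note, the `itertools.product` call in the `@parameterized.parameters` decorator above expands `test_task` into four cases; a quick way to see the generated parameter tuples:

```python
import itertools

# The (version_2_with_negative, tokenization) combinations generated above:
print(list(itertools.product((False, True), ("WordPiece", "SentencePiece"))))
# [(False, 'WordPiece'), (False, 'SentencePiece'),
#  (True, 'WordPiece'), (True, 'SentencePiece')]
```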