ModelZoo / ResNet50_tensorflow · Commits

Commit 9c9aec17, authored Dec 20, 2019 by Chen Chen, committed by A. Unique TensorFlower on Dec 20, 2019

Support to run ALBERT on SQuAD task.

PiperOrigin-RevId: 286637307

Parent: 553a4f41
Changes: 4 changed files, 953 additions and 35 deletions (+953 −35)
official/nlp/bert/create_finetuning_data.py  (+36 −10)
official/nlp/bert/run_squad.py               (+36 −15)
official/nlp/bert/squad_lib_sp.py            (+868 −0, new file)
official/nlp/bert/tokenization.py            (+13 −10)
official/nlp/bert/create_finetuning_data.py
@@ -25,7 +25,10 @@ from absl import flags
 import tensorflow as tf

 from official.nlp.bert import classifier_data_lib
-from official.nlp.bert import squad_lib
+# word-piece tokenizer based squad_lib
+from official.nlp.bert import squad_lib as squad_lib_wp
+# sentence-piece tokenizer based squad_lib
+from official.nlp.bert import squad_lib_sp

 FLAGS = flags.FLAGS
@@ -70,14 +73,12 @@ flags.DEFINE_string("vocab_file", None,
 flags.DEFINE_string("train_data_output_path", None,
                     "The path in which generated training input data will be written as tf"
                     " records.")

 flags.DEFINE_string("eval_data_output_path", None,
                     "The path in which generated training input data will be written as tf"
                     " records.")

 flags.DEFINE_string("meta_data_file_path", None,
                     "The path in which input meta data will be written.")
@@ -93,6 +94,15 @@ flags.DEFINE_integer(
     "Sequences longer than this will be truncated, and sequences shorter "
     "than this will be padded.")

+flags.DEFINE_string("sp_model_file", "",
+                    "The path to the model used by sentence piece tokenizer.")
+
+flags.DEFINE_enum("tokenizer_impl", "word_piece",
+                  ["word_piece", "sentence_piece"],
+                  "Specifies the tokenizer implementation, i.e., whether to use word_piece "
+                  "or sentence_piece tokenizer. Canonical BERT uses word_piece tokenizer, "
+                  "while ALBERT uses sentence_piece tokenizer.")
+

 def generate_classifier_dataset():
   """Generates classifier dataset and returns input meta data."""
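For orientation, a minimal, hypothetical absl.flags sketch (not part of this commit) of how the new tokenizer_impl enum behaves: DEFINE_enum rejects any value outside the listed choices during flag parsing, so only the word_piece and sentence_piece code paths shown below are reachable.

# Illustrative sketch only; the flag mirrors the definition above, the tiny script is hypothetical.
from absl import app, flags

flags.DEFINE_enum("tokenizer_impl", "word_piece",
                  ["word_piece", "sentence_piece"],
                  "Which tokenizer implementation to use.")
FLAGS = flags.FLAGS


def main(_):
  # Running with --tokenizer_impl=sentence_piece selects the ALBERT path;
  # any other value fails flag validation before main() is entered.
  print(FLAGS.tokenizer_impl)


if __name__ == "__main__":
  app.run(main)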
@@ -124,13 +134,30 @@ def generate_classifier_dataset():
 def generate_squad_dataset():
   """Generates squad training dataset and returns input meta data."""
   assert FLAGS.squad_data_file
-  return squad_lib.generate_tf_record_from_json_file(
-      FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
-      FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
-      FLAGS.doc_stride, FLAGS.version_2_with_negative)
+  if FLAGS.tokenizer_impl == "word_piece":
+    return squad_lib_wp.generate_tf_record_from_json_file(
+        FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
+        FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
+        FLAGS.doc_stride, FLAGS.version_2_with_negative)
+  else:
+    assert FLAGS.tokenizer_impl == "sentence_piece"
+    return squad_lib_sp.generate_tf_record_from_json_file(
+        FLAGS.squad_data_file, FLAGS.sp_model_file,
+        FLAGS.train_data_output_path, FLAGS.max_seq_length,
+        FLAGS.do_lower_case, FLAGS.max_query_length, FLAGS.doc_stride,
+        FLAGS.version_2_with_negative)


 def main(_):
+  if FLAGS.tokenizer_impl == "word_piece":
+    if not FLAGS.vocab_file:
+      raise ValueError(
+          "FLAG vocab_file for word-piece tokenizer is not specified.")
+  else:
+    assert FLAGS.tokenizer_impl == "sentence_piece"
+    if not FLAGS.sp_model_file:
+      raise ValueError(
+          "FLAG sp_model_file for sentence-piece tokenizer is not specified.")
+
   if FLAGS.fine_tuning_task_type == "classification":
     input_meta_data = generate_classifier_dataset()
   else:
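As a usage sketch (not part of the commit), the sentence-piece path can also be exercised directly from Python; the paths and values below are placeholders, and the positional argument order simply mirrors the call in generate_squad_dataset above.

# Hypothetical example; file paths and values are placeholders, and the argument
# order follows the call site above (data file, sp model, output path, lengths, ...).
from official.nlp.bert import squad_lib_sp

input_meta_data = squad_lib_sp.generate_tf_record_from_json_file(
    "/tmp/squad/train-v1.1.json",         # FLAGS.squad_data_file
    "/tmp/albert/30k-clean.model",        # FLAGS.sp_model_file (SentencePiece model)
    "/tmp/albert_squad/train.tf_record",  # FLAGS.train_data_output_path
    384,    # FLAGS.max_seq_length
    True,   # FLAGS.do_lower_case
    64,     # FLAGS.max_query_length
    128,    # FLAGS.doc_stride
    False)  # FLAGS.version_2_with_negative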
@@ -141,7 +168,6 @@ def main(_):

 if __name__ == "__main__":
-  flags.mark_flag_as_required("vocab_file")
   flags.mark_flag_as_required("train_data_output_path")
   flags.mark_flag_as_required("meta_data_file_path")
   app.run(main)
official/nlp/bert/run_squad.py
@@ -34,7 +34,10 @@ from official.nlp import optimization
 from official.nlp.bert import common_flags
 from official.nlp.bert import input_pipeline
 from official.nlp.bert import model_saving_utils
-from official.nlp.bert import squad_lib
+# word-piece tokenizer based squad_lib
+from official.nlp.bert import squad_lib as squad_lib_wp
+# sentence-piece tokenizer based squad_lib
+from official.nlp.bert import squad_lib_sp
 from official.nlp.bert import tokenization
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
@@ -80,11 +83,22 @@ flags.DEFINE_integer(
     'max_answer_length', 30,
     'The maximum length of an answer that can be generated. This is needed '
     'because the start and end predictions are not conditioned on one another.')
+flags.DEFINE_string(
+    'sp_model_file', None,
+    'The path to the sentence piece model. Used by sentence piece tokenizer '
+    'employed by ALBERT.')

 common_flags.define_common_bert_flags()

 FLAGS = flags.FLAGS

+MODEL_CLASSES = {
+    'bert': (modeling.BertConfig, squad_lib_wp, tokenization.FullTokenizer),
+    'albert': (modeling.AlbertConfig, squad_lib_sp,
+               tokenization.FullSentencePieceTokenizer),
+}
+

 def squad_loss_fn(start_positions,
                   end_positions,
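A short, hypothetical sketch of how this MODEL_CLASSES table is consumed; predict_squad further down does essentially this, and the file paths here are placeholders.

# Illustrative only: the 'albert' entry yields the config class, the sentence-piece
# squad_lib variant, and the sentence-piece tokenizer class.
config_cls, squad_lib, tokenizer_cls = MODEL_CLASSES['albert']

bert_config = config_cls.from_json_file('/tmp/albert/albert_config.json')  # modeling.AlbertConfig
tokenizer = tokenizer_cls(sp_model_file='/tmp/albert/30k-clean.model')     # FullSentencePieceTokenizer
# squad_lib is squad_lib_sp here, so feature conversion later takes an extra
# do_lower_case argument (see the kwargs handling in predict_squad below).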
@@ -121,6 +135,7 @@ def get_loss_fn(loss_factor=1.0):
 def get_raw_results(predictions):
   """Converts multi-replica predictions to RawResult."""
+  squad_lib = MODEL_CLASSES[FLAGS.model_type][1]
   for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
                                                   predictions['start_logits'],
                                                   predictions['end_logits']):
@@ -167,9 +182,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
   # Prediction always uses float32, even if training uses mixed precision.
   tf.keras.mixed_precision.experimental.set_policy('float32')
   squad_model, _ = bert_models.squad_model(
-      bert_config,
-      input_meta_data['max_seq_length'],
-      float_type=tf.float32)
+      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)

   checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
   logging.info('Restoring checkpoints from %s', checkpoint_path)
@@ -219,7 +232,8 @@ def train_squad(strategy,
   if use_float16:
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
+      FLAGS.bert_config_file)
   epochs = FLAGS.num_train_epochs
   num_train_examples = input_meta_data['train_data_size']
   max_seq_length = input_meta_data['max_seq_length']
@@ -281,7 +295,14 @@ def train_squad(strategy,
 def predict_squad(strategy, input_meta_data):
   """Makes predictions for a squad dataset."""
-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  config_cls, squad_lib, tokenizer_cls = MODEL_CLASSES[FLAGS.model_type]
+  bert_config = config_cls.from_json_file(FLAGS.bert_config_file)
+  if tokenizer_cls == tokenization.FullTokenizer:
+    tokenizer = tokenizer_cls(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+  else:
+    assert tokenizer_cls == tokenization.FullSentencePieceTokenizer
+    tokenizer = tokenizer_cls(sp_model_file=FLAGS.sp_model_file)
   doc_stride = input_meta_data['doc_stride']
   max_query_length = input_meta_data['max_query_length']
   # Whether data should be in Ver 2.0 format.
@@ -292,9 +313,6 @@ def predict_squad(strategy, input_meta_data):
       is_training=False,
       version_2_with_negative=version_2_with_negative)

-  tokenizer = tokenization.FullTokenizer(
-      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
-
   eval_writer = squad_lib.FeatureWriter(
       filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
       is_training=False)
@@ -309,7 +327,7 @@ def predict_squad(strategy, input_meta_data):
   # of examples must be a multiple of the batch size, or else examples
   # will get dropped. So we pad with fake examples which are ignored
   # later on.
-  dataset_size = squad_lib.convert_examples_to_features(
+  kwargs = dict(
       examples=eval_examples,
       tokenizer=tokenizer,
       max_seq_length=input_meta_data['max_seq_length'],
@@ -318,6 +336,11 @@ def predict_squad(strategy, input_meta_data):
       is_training=False,
       output_fn=_append_feature,
       batch_size=FLAGS.predict_batch_size)
+
+  # squad_lib_sp requires one more argument 'do_lower_case'.
+  if squad_lib == squad_lib_sp:
+    kwargs['do_lower_case'] = FLAGS.do_lower_case
+  dataset_size = squad_lib.convert_examples_to_features(**kwargs)
   eval_writer.close()

   logging.info('***** Running predictions *****')
@@ -358,12 +381,10 @@ def export_squad(model_export_path, input_meta_data):
   """
   if not model_export_path:
     raise ValueError('Export path is not specified: %s' % model_export_path)
-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
+      FLAGS.bert_config_file)
   squad_model, _ = bert_models.squad_model(
-      bert_config,
-      input_meta_data['max_seq_length'],
-      float_type=tf.float32)
+      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
   model_saving_utils.export_bert_model(
       model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
official/nlp/bert/squad_lib_sp.py  (new file, mode 100644; +868 −0, diff collapsed in this view)
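The new 868-line module is collapsed in this view. Judging only from the call sites in create_finetuning_data.py and run_squad.py above, its public surface looks roughly like the stub below; the parameter names are inferred from those calls, not copied from the file, so treat them as assumptions.

# Hypothetical stub inferred from the call sites in this commit; the actual
# squad_lib_sp.py (collapsed above) may differ in names, defaults and helpers.

def generate_tf_record_from_json_file(input_file_path, sp_model_file,
                                      output_path, max_seq_length,
                                      do_lower_case, max_query_length,
                                      doc_stride, version_2_with_negative):
  """Writes SQuAD training features to TFRecords using a SentencePiece model
  and returns the input meta data dict (mirroring the word-piece squad_lib)."""


def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 doc_stride, max_query_length, is_training,
                                 output_fn, batch_size, do_lower_case):
  """Like squad_lib.convert_examples_to_features, plus the extra do_lower_case
  argument handled via kwargs in run_squad.predict_squad."""


class FeatureWriter(object):
  """Writes InputFeatures to a TFRecord file (same role as in squad_lib)."""

  def __init__(self, filename, is_training):
    ...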
official/nlp/bert/tokenization.py
@@ -32,7 +32,7 @@ import tensorflow as tf
 import sentencepiece as spm

-SPIECE_UNDERLINE = u"▁".encode("utf-8")
+SPIECE_UNDERLINE = "▁"


 def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
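A brief illustration (not from the commit) of why the constant switches from bytes to text: under Python 3 the pieces handled in encode_pieces are str, so a bytes marker can never compare equal.

# Illustrative only.
piece = "▁the"                     # a typical SentencePiece token piece (str in Python 3)
piece[0] == u"▁".encode("utf-8")   # False: str compared with bytes
piece[0] == "▁"                    # True with the new text constant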
@@ -458,6 +458,9 @@ def encode_pieces(sp_model, text, sample=False):
   Returns:
     A list of token pieces.
   """
+  if six.PY2 and isinstance(text, six.text_type):
+    text = six.ensure_binary(text, "utf-8")
+
   if not sample:
     pieces = sp_model.EncodeAsPieces(text)
   else:
@@ -466,8 +469,8 @@ def encode_pieces(sp_model, text, sample=False):
   for piece in pieces:
     piece = printable_text(piece)
     if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
-      cur_pieces = sp_model.EncodeAsPieces(
-          six.ensure_binary(piece[:-1]).replace(SPIECE_UNDERLINE, b""))
+      cur_pieces = sp_model.EncodeAsPieces(
+          piece[:-1].replace(SPIECE_UNDERLINE, ""))
       if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
         if len(cur_pieces[0]) == 1:
           cur_pieces = cur_pieces[1:]
@@ -514,21 +517,21 @@ class FullSentencePieceTokenizer(object):
     Args:
       sp_model_file: The path to the sentence piece model file.
     """
-    self._sp_model = spm.SentencePieceProcessor()
-    self._sp_model.Load(sp_model_file)
+    self.sp_model = spm.SentencePieceProcessor()
+    self.sp_model.Load(sp_model_file)
     self.vocab = {
-        self._sp_model.IdToPiece(i): i
-        for i in six.moves.range(self._sp_model.GetPieceSize())
+        self.sp_model.IdToPiece(i): i
+        for i in six.moves.range(self.sp_model.GetPieceSize())
     }

   def tokenize(self, text):
     """Tokenizes text into pieces."""
-    return encode_pieces(self._sp_model, text)
+    return encode_pieces(self.sp_model, text)

   def convert_tokens_to_ids(self, tokens):
     """Converts a list of tokens to a list of ids."""
-    return [self._sp_model.PieceToId(printable_text(token)) for token in tokens]
+    return [self.sp_model.PieceToId(printable_text(token)) for token in tokens]

   def convert_ids_to_tokens(self, ids):
     """Converts a list of ids to a list of tokens."""
-    return [self._sp_model.IdToPiece(id_) for id_ in ids]
+    return [self.sp_model.IdToPiece(id_) for id_ in ids]
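Closing with a hypothetical usage sketch (the model path and sentence are placeholders, not from the commit): making sp_model a public attribute is presumably what lets squad_lib_sp reach the underlying SentencePiece processor directly.

# Illustrative only; '/tmp/albert/30k-clean.model' is a placeholder path.
from official.nlp.bert import tokenization

tokenizer = tokenization.FullSentencePieceTokenizer(
    sp_model_file="/tmp/albert/30k-clean.model")

pieces = tokenizer.tokenize("Where was the 2016 Super Bowl held?")
ids = tokenizer.convert_tokens_to_ids(pieces)
tokens = tokenizer.convert_ids_to_tokens(ids)

# After this commit the processor is a public attribute, so callers such as
# squad_lib_sp can use it directly, e.g.:
vocab_size = tokenizer.sp_model.GetPieceSize()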