ModelZoo / ResNet50_tensorflow · Commits

Commit 9c9aec17, authored Dec 20, 2019 by Chen Chen; committed by A. Unique TensorFlower, Dec 20, 2019

Support to run ALBERT on SQuAD task.

PiperOrigin-RevId: 286637307
Parent: 553a4f41

Showing 4 changed files, with 953 additions and 35 deletions:

- official/nlp/bert/create_finetuning_data.py (+36, -10)
- official/nlp/bert/run_squad.py (+36, -15)
- official/nlp/bert/squad_lib_sp.py (+868, -0, new file)
- official/nlp/bert/tokenization.py (+13, -10)
official/nlp/bert/create_finetuning_data.py

```diff
@@ -25,7 +25,10 @@ from absl import flags
 import tensorflow as tf

 from official.nlp.bert import classifier_data_lib
-from official.nlp.bert import squad_lib
+# word-piece tokenizer based squad_lib
+from official.nlp.bert import squad_lib as squad_lib_wp
+# sentence-piece tokenizer based squad_lib
+from official.nlp.bert import squad_lib_sp

 FLAGS = flags.FLAGS
@@ -70,14 +73,12 @@ flags.DEFINE_string("vocab_file", None,
 flags.DEFINE_string(
     "train_data_output_path", None,
     "The path in which generated training input data will be written as tf"
-    " records."
-)
+    " records.")
 flags.DEFINE_string(
     "eval_data_output_path", None,
     "The path in which generated training input data will be written as tf"
-    " records."
-)
+    " records.")
 flags.DEFINE_string("meta_data_file_path", None,
                     "The path in which input meta data will be written.")
@@ -93,6 +94,15 @@ flags.DEFINE_integer(
     "Sequences longer than this will be truncated, and sequences shorter "
     "than this will be padded.")

+flags.DEFINE_string("sp_model_file", "",
+                    "The path to the model used by sentence piece tokenizer.")
+
+flags.DEFINE_enum(
+    "tokenizer_impl", "word_piece", ["word_piece", "sentence_piece"],
+    "Specifies the tokenizer implementation, i.e., whether to use word_piece "
+    "or sentence_piece tokenizer. Canonical BERT uses word_piece tokenizer, "
+    "while ALBERT uses sentence_piece tokenizer.")
+

 def generate_classifier_dataset():
   """Generates classifier dataset and returns input meta data."""
@@ -124,13 +134,30 @@ def generate_classifier_dataset():
 def generate_squad_dataset():
   """Generates squad training dataset and returns input meta data."""
   assert FLAGS.squad_data_file
-  return squad_lib.generate_tf_record_from_json_file(
-      FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
-      FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
-      FLAGS.doc_stride, FLAGS.version_2_with_negative)
+  if FLAGS.tokenizer_impl == "word_piece":
+    return squad_lib_wp.generate_tf_record_from_json_file(
+        FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
+        FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
+        FLAGS.doc_stride, FLAGS.version_2_with_negative)
+  else:
+    assert FLAGS.tokenizer_impl == "sentence_piece"
+    return squad_lib_sp.generate_tf_record_from_json_file(
+        FLAGS.squad_data_file, FLAGS.sp_model_file,
+        FLAGS.train_data_output_path, FLAGS.max_seq_length,
+        FLAGS.do_lower_case, FLAGS.max_query_length, FLAGS.doc_stride,
+        FLAGS.version_2_with_negative)


 def main(_):
+  if FLAGS.tokenizer_impl == "word_piece":
+    if not FLAGS.vocab_file:
+      raise ValueError(
+          "FLAG vocab_file for word-piece tokenizer is not specified.")
+  else:
+    assert FLAGS.tokenizer_impl == "sentence_piece"
+    if not FLAGS.sp_model_file:
+      raise ValueError(
+          "FLAG sp_model_file for sentence-piece tokenizer is not specified.")
+
   if FLAGS.fine_tuning_task_type == "classification":
     input_meta_data = generate_classifier_dataset()
   else:
@@ -141,7 +168,6 @@ def main(_):
 if __name__ == "__main__":
-  flags.mark_flag_as_required("vocab_file")
   flags.mark_flag_as_required("train_data_output_path")
   flags.mark_flag_as_required("meta_data_file_path")
   app.run(main)
```
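To make the new path concrete, here is a minimal sketch of driving the sentence-piece pipeline directly rather than through the flags; the file paths are hypothetical placeholders, and the signature is the one `squad_lib_sp.generate_tf_record_from_json_file` gains below.

```python
# Hypothetical paths; the signature matches squad_lib_sp in this commit.
from official.nlp.bert import squad_lib_sp

input_meta_data = squad_lib_sp.generate_tf_record_from_json_file(
    input_file_path="/tmp/squad/train-v1.1.json",   # SQuAD training JSON
    sp_model_file="/tmp/albert/30k-clean.model",    # SentencePiece model
    output_path="/tmp/albert/squad_train.tf_record",
    max_seq_length=384,
    do_lower_case=True,
    max_query_length=64,
    doc_stride=128,
    version_2_with_negative=False)
print(input_meta_data["train_data_size"])
```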
official/nlp/bert/run_squad.py

```diff
@@ -34,7 +34,10 @@ from official.nlp import optimization
 from official.nlp.bert import common_flags
 from official.nlp.bert import input_pipeline
 from official.nlp.bert import model_saving_utils
-from official.nlp.bert import squad_lib
+# word-piece tokenizer based squad_lib
+from official.nlp.bert import squad_lib as squad_lib_wp
+# sentence-piece tokenizer based squad_lib
+from official.nlp.bert import squad_lib_sp
 from official.nlp.bert import tokenization
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
@@ -80,11 +83,22 @@ flags.DEFINE_integer(
     'max_answer_length', 30,
     'The maximum length of an answer that can be generated. This is needed '
     'because the start and end predictions are not conditioned on one another.'
 )
+flags.DEFINE_string(
+    'sp_model_file', None,
+    'The path to the sentence piece model. Used by sentence piece tokenizer '
+    'employed by ALBERT.')

 common_flags.define_common_bert_flags()

 FLAGS = flags.FLAGS

+MODEL_CLASSES = {
+    'bert': (modeling.BertConfig, squad_lib_wp, tokenization.FullTokenizer),
+    'albert': (modeling.AlbertConfig, squad_lib_sp,
+               tokenization.FullSentencePieceTokenizer),
+}
+

 def squad_loss_fn(start_positions,
                   end_positions,
@@ -121,6 +135,7 @@ def get_loss_fn(loss_factor=1.0):
 def get_raw_results(predictions):
   """Converts multi-replica predictions to RawResult."""
+  squad_lib = MODEL_CLASSES[FLAGS.model_type][1]
   for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
                                                   predictions['start_logits'],
                                                   predictions['end_logits']):
@@ -167,9 +182,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
   # Prediction always uses float32, even if training uses mixed precision.
   tf.keras.mixed_precision.experimental.set_policy('float32')
-  squad_model, _ = bert_models.squad_model(
-      bert_config,
-      input_meta_data['max_seq_length'],
-      float_type=tf.float32)
+  squad_model, _ = bert_models.squad_model(
+      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)

   checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
   logging.info('Restoring checkpoints from %s', checkpoint_path)
@@ -219,7 +232,8 @@ def train_squad(strategy,
   if use_float16:
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
+      FLAGS.bert_config_file)
   epochs = FLAGS.num_train_epochs
   num_train_examples = input_meta_data['train_data_size']
   max_seq_length = input_meta_data['max_seq_length']
@@ -281,7 +295,14 @@ def train_squad(strategy,
 def predict_squad(strategy, input_meta_data):
   """Makes predictions for a squad dataset."""
-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  config_cls, squad_lib, tokenizer_cls = MODEL_CLASSES[FLAGS.model_type]
+  bert_config = config_cls.from_json_file(FLAGS.bert_config_file)
+  if tokenizer_cls == tokenization.FullTokenizer:
+    tokenizer = tokenizer_cls(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+  else:
+    assert tokenizer_cls == tokenization.FullSentencePieceTokenizer
+    tokenizer = tokenizer_cls(sp_model_file=FLAGS.sp_model_file)
   doc_stride = input_meta_data['doc_stride']
   max_query_length = input_meta_data['max_query_length']
   # Whether data should be in Ver 2.0 format.
@@ -292,9 +313,6 @@ def predict_squad(strategy, input_meta_data):
       is_training=False,
       version_2_with_negative=version_2_with_negative)

-  tokenizer = tokenization.FullTokenizer(
-      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
-
   eval_writer = squad_lib.FeatureWriter(
       filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
       is_training=False)
@@ -309,7 +327,7 @@ def predict_squad(strategy, input_meta_data):
   # of examples must be a multiple of the batch size, or else examples
   # will get dropped. So we pad with fake examples which are ignored
   # later on.
-  dataset_size = squad_lib.convert_examples_to_features(
+  kwargs = dict(
       examples=eval_examples,
       tokenizer=tokenizer,
       max_seq_length=input_meta_data['max_seq_length'],
@@ -318,6 +336,11 @@ def predict_squad(strategy, input_meta_data):
       is_training=False,
       output_fn=_append_feature,
       batch_size=FLAGS.predict_batch_size)
+
+  # squad_lib_sp requires one more argument 'do_lower_case'.
+  if squad_lib == squad_lib_sp:
+    kwargs['do_lower_case'] = FLAGS.do_lower_case
+  dataset_size = squad_lib.convert_examples_to_features(**kwargs)
   eval_writer.close()

   logging.info('***** Running predictions *****')
@@ -358,12 +381,10 @@ def export_squad(model_export_path, input_meta_data):
   """
   if not model_export_path:
     raise ValueError('Export path is not specified: %s' % model_export_path)
-  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
+  bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
+      FLAGS.bert_config_file)

   squad_model, _ = bert_models.squad_model(
-      bert_config,
-      input_meta_data['max_seq_length'],
-      float_type=tf.float32)
+      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
   model_saving_utils.export_bert_model(
       model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
```
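The `MODEL_CLASSES` registry is the pivot of this change: each entry bundles a config class, a squad_lib module, and a tokenizer class, so `train_squad`, `predict_squad`, and `export_squad` stay model-agnostic. A self-contained sketch of the same dispatch pattern, with the real classes stubbed out as strings so it can run on its own:

```python
# Stub registry mirroring MODEL_CLASSES above; the real entries hold
# (modeling.BertConfig | modeling.AlbertConfig, a squad_lib module,
#  a tokenizer class).
MODEL_CLASSES = {
    "bert": ("BertConfig", "squad_lib_wp", "FullTokenizer"),
    "albert": ("AlbertConfig", "squad_lib_sp", "FullSentencePieceTokenizer"),
}

def resolve(model_type):
  # Same triple-unpacking as predict_squad in the diff above.
  config_cls, squad_lib, tokenizer_cls = MODEL_CLASSES[model_type]
  return config_cls, squad_lib, tokenizer_cls

print(resolve("albert"))
# ('AlbertConfig', 'squad_lib_sp', 'FullSentencePieceTokenizer')
```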
official/nlp/bert/squad_lib_sp.py (new file, mode 100644)

```python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 using sentence piece tokenization.

The file is forked from:
https://github.com/google-research/ALBERT/blob/master/run_squad_sp.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import json
import math

from absl import logging
import numpy as np
import tensorflow as tf

from official.nlp.bert import tokenization


class SquadExample(object):
  """A single training/test example for simple sequence classification.

     For examples without an answer, the start and end position are -1.
  """

  def __init__(self,
               qas_id,
               question_text,
               paragraph_text,
               orig_answer_text=None,
               start_position=None,
               end_position=None,
               is_impossible=False):
    self.qas_id = qas_id
    self.question_text = question_text
    self.paragraph_text = paragraph_text
    self.orig_answer_text = orig_answer_text
    self.start_position = start_position
    self.end_position = end_position
    self.is_impossible = is_impossible

  def __str__(self):
    return self.__repr__()

  def __repr__(self):
    s = ""
    s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
    s += ", question_text: %s" % (
        tokenization.printable_text(self.question_text))
    s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text))
    if self.start_position:
      s += ", start_position: %d" % (self.start_position)
    if self.start_position:
      s += ", end_position: %d" % (self.end_position)
    if self.start_position:
      s += ", is_impossible: %r" % (self.is_impossible)
    return s


class InputFeatures(object):
  """A single set of features of data."""

  def __init__(self,
               unique_id,
               example_index,
               doc_span_index,
               tok_start_to_orig_index,
               tok_end_to_orig_index,
               token_is_max_context,
               tokens,
               input_ids,
               input_mask,
               segment_ids,
               paragraph_len,
               start_position=None,
               end_position=None,
               is_impossible=None):
    self.unique_id = unique_id
    self.example_index = example_index
    self.doc_span_index = doc_span_index
    self.tok_start_to_orig_index = tok_start_to_orig_index
    self.tok_end_to_orig_index = tok_end_to_orig_index
    self.token_is_max_context = token_is_max_context
    self.tokens = tokens
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.segment_ids = segment_ids
    self.paragraph_len = paragraph_len
    self.start_position = start_position
    self.end_position = end_position
    self.is_impossible = is_impossible


def read_squad_examples(input_file, is_training, version_2_with_negative):
  """Read a SQuAD json file into a list of SquadExample."""
  del version_2_with_negative
  with tf.io.gfile.GFile(input_file, "r") as reader:
    input_data = json.load(reader)["data"]

  examples = []
  for entry in input_data:
    for paragraph in entry["paragraphs"]:
      paragraph_text = paragraph["context"]

      for qa in paragraph["qas"]:
        qas_id = qa["id"]
        question_text = qa["question"]
        start_position = None
        orig_answer_text = None
        is_impossible = False

        if is_training:
          is_impossible = qa.get("is_impossible", False)
          if (len(qa["answers"]) != 1) and (not is_impossible):
            raise ValueError(
                "For training, each question should have exactly 1 answer.")
          if not is_impossible:
            answer = qa["answers"][0]
            orig_answer_text = answer["text"]
            start_position = answer["answer_start"]
          else:
            start_position = -1
            orig_answer_text = ""

        example = SquadExample(
            qas_id=qas_id,
            question_text=question_text,
            paragraph_text=paragraph_text,
            orig_answer_text=orig_answer_text,
            start_position=start_position,
            is_impossible=is_impossible)
        examples.append(example)

  return examples
```
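For orientation, `read_squad_examples` walks the standard SQuAD JSON layout (`data` to `paragraphs` to `qas`). A toy record with made-up content, shaped the way the function expects:

```python
# Toy record (fabricated content, standard SQuAD 1.1 shape) showing the
# fields read_squad_examples consumes. With is_training=True, exactly one
# answer per question is required unless is_impossible is set.
toy_squad = {
    "data": [{
        "paragraphs": [{
            "context": "TensorFlow was released in 2015.",
            "qas": [{
                "id": "q1",
                "question": "When was TensorFlow released?",
                "answers": [{"text": "2015", "answer_start": 27}],
            }],
        }],
    }]
}
```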
```python
def _convert_index(index, pos, m=None, is_start=True):
  """Converts index."""
  if index[pos] is not None:
    return index[pos]
  n = len(index)
  rear = pos
  while rear < n - 1 and index[rear] is None:
    rear += 1
  front = pos
  while front > 0 and index[front] is None:
    front -= 1
  assert index[front] is not None or index[rear] is not None
  if index[front] is None:
    if index[rear] >= 1:
      if is_start:
        return 0
      else:
        return index[rear] - 1
    return index[rear]
  if index[rear] is None:
    if m is not None and index[front] < m - 1:
      if is_start:
        return index[front] + 1
      else:
        return m - 1
    return index[front]
  if is_start:
    if index[rear] > index[front] + 1:
      return index[front] + 1
    else:
      return index[rear]
  else:
    if index[rear] > index[front] + 1:
      return index[rear] - 1
    else:
      return index[front]
```
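`_convert_index` resolves positions the alignment left unmapped (`None`) by scanning to the nearest mapped neighbor on either side, nudging the result inward for span starts and outward for span ends. A quick illustrative trace with a hand-made index list (values chosen only for illustration; runs against the function above):

```python
# index maps positions to alignment targets; None marks unaligned positions.
index = [0, None, None, 5, 6]
print(_convert_index(index, 1, is_start=True))   # 1: index[front]=0, plus 1
print(_convert_index(index, 2, is_start=False))  # 4: index[rear]=5, minus 1
```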
```python
def convert_examples_to_features(examples,
                                 tokenizer,
                                 max_seq_length,
                                 doc_stride,
                                 max_query_length,
                                 is_training,
                                 output_fn,
                                 do_lower_case,
                                 batch_size=None):
  """Loads a data file into a list of `InputBatch`s."""
  cnt_pos, cnt_neg = 0, 0
  base_id = 1000000000
  unique_id = base_id
  max_n, max_m = 1024, 1024
  f = np.zeros((max_n, max_m), dtype=np.float32)

  for (example_index, example) in enumerate(examples):

    if example_index % 100 == 0:
      logging.info("Converting %d/%d pos %d neg %d", example_index,
                   len(examples), cnt_pos, cnt_neg)

    query_tokens = tokenization.encode_ids(
        tokenizer.sp_model,
        tokenization.preprocess_text(
            example.question_text, lower=do_lower_case))

    if len(query_tokens) > max_query_length:
      query_tokens = query_tokens[0:max_query_length]

    paragraph_text = example.paragraph_text
    para_tokens = tokenization.encode_pieces(
        tokenizer.sp_model,
        tokenization.preprocess_text(
            example.paragraph_text, lower=do_lower_case))

    chartok_to_tok_index = []
    tok_start_to_chartok_index = []
    tok_end_to_chartok_index = []
    char_cnt = 0
    for i, token in enumerate(para_tokens):
      new_token = token.replace(tokenization.SPIECE_UNDERLINE, " ")
      chartok_to_tok_index.extend([i] * len(new_token))
      tok_start_to_chartok_index.append(char_cnt)
      char_cnt += len(new_token)
      tok_end_to_chartok_index.append(char_cnt - 1)

    tok_cat_text = "".join(para_tokens).replace(
        tokenization.SPIECE_UNDERLINE, " ")
    n, m = len(paragraph_text), len(tok_cat_text)

    if n > max_n or m > max_m:
      max_n = max(n, max_n)
      max_m = max(m, max_m)
      f = np.zeros((max_n, max_m), dtype=np.float32)

    g = {}

    # pylint: disable=cell-var-from-loop
    def _lcs_match(max_dist, n=n, m=m):
      """Longest-common-substring algorithm."""
      f.fill(0)
      g.clear()

      ### longest common sub sequence
      # f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
      for i in range(n):

        # unlike standard LCS, this is specifically optimized for the setting
        # because the mismatch between sentence pieces and original text will
        # be small
        for j in range(i - max_dist, i + max_dist):
          if j >= m or j < 0:
            continue

          if i > 0:
            g[(i, j)] = 0
            f[i, j] = f[i - 1, j]

          if j > 0 and f[i, j - 1] > f[i, j]:
            g[(i, j)] = 1
            f[i, j] = f[i, j - 1]

          f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
          if (tokenization.preprocess_text(
              paragraph_text[i], lower=do_lower_case,
              remove_space=False) == tok_cat_text[j] and f_prev + 1 > f[i, j]):
            g[(i, j)] = 2
            f[i, j] = f_prev + 1

    # pylint: enable=cell-var-from-loop

    max_dist = abs(n - m) + 5
    for _ in range(2):
      _lcs_match(max_dist)
      if f[n - 1, m - 1] > 0.8 * n:
        break
      max_dist *= 2

    orig_to_chartok_index = [None] * n
    chartok_to_orig_index = [None] * m
    i, j = n - 1, m - 1
    while i >= 0 and j >= 0:
      if (i, j) not in g:
        break
      if g[(i, j)] == 2:
        orig_to_chartok_index[i] = j
        chartok_to_orig_index[j] = i
        i, j = i - 1, j - 1
      elif g[(i, j)] == 1:
        j = j - 1
      else:
        i = i - 1

    if (all(v is None for v in orig_to_chartok_index) or
        f[n - 1, m - 1] < 0.8 * n):
      logging.info("MISMATCH DETECTED!")
      continue

    tok_start_to_orig_index = []
    tok_end_to_orig_index = []
    for i in range(len(para_tokens)):
      start_chartok_pos = tok_start_to_chartok_index[i]
      end_chartok_pos = tok_end_to_chartok_index[i]
      start_orig_pos = _convert_index(
          chartok_to_orig_index, start_chartok_pos, n, is_start=True)
      end_orig_pos = _convert_index(
          chartok_to_orig_index, end_chartok_pos, n, is_start=False)

      tok_start_to_orig_index.append(start_orig_pos)
      tok_end_to_orig_index.append(end_orig_pos)

    if not is_training:
      tok_start_position = tok_end_position = None

    if is_training and example.is_impossible:
      tok_start_position = 0
      tok_end_position = 0

    if is_training and not example.is_impossible:
      start_position = example.start_position
      end_position = start_position + len(example.orig_answer_text) - 1

      start_chartok_pos = _convert_index(
          orig_to_chartok_index, start_position, is_start=True)
      tok_start_position = chartok_to_tok_index[start_chartok_pos]

      end_chartok_pos = _convert_index(
          orig_to_chartok_index, end_position, is_start=False)
      tok_end_position = chartok_to_tok_index[end_chartok_pos]
      assert tok_start_position <= tok_end_position

    def _piece_to_id(x):
      return tokenizer.sp_model.PieceToId(x)

    all_doc_tokens = list(map(_piece_to_id, para_tokens))

    # The -3 accounts for [CLS], [SEP] and [SEP]
    max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

    # We can have documents that are longer than the maximum sequence length.
    # To deal with this we do a sliding window approach, where we take chunks
    # of the up to our max length with a stride of `doc_stride`.
    _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
        "DocSpan", ["start", "length"])
    doc_spans = []
    start_offset = 0
    while start_offset < len(all_doc_tokens):
      length = len(all_doc_tokens) - start_offset
      if length > max_tokens_for_doc:
        length = max_tokens_for_doc
      doc_spans.append(_DocSpan(start=start_offset, length=length))
      if start_offset + length == len(all_doc_tokens):
        break
      start_offset += min(length, doc_stride)

    for (doc_span_index, doc_span) in enumerate(doc_spans):
      tokens = []
      token_is_max_context = {}
      segment_ids = []
      cur_tok_start_to_orig_index = []
      cur_tok_end_to_orig_index = []

      tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
      segment_ids.append(0)
      for token in query_tokens:
        tokens.append(token)
        segment_ids.append(0)
      tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
      segment_ids.append(0)

      for i in range(doc_span.length):
        split_token_index = doc_span.start + i

        cur_tok_start_to_orig_index.append(
            tok_start_to_orig_index[split_token_index])
        cur_tok_end_to_orig_index.append(
            tok_end_to_orig_index[split_token_index])

        is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                               split_token_index)
        token_is_max_context[len(tokens)] = is_max_context
        tokens.append(all_doc_tokens[split_token_index])
        segment_ids.append(1)
      tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
      segment_ids.append(1)

      paragraph_len = len(tokens)
      input_ids = tokens

      # The mask has 1 for real tokens and 0 for padding tokens. Only real
      # tokens are attended to.
      input_mask = [1] * len(input_ids)

      # Zero-pad up to the sequence length.
      while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)

      assert len(input_ids) == max_seq_length
      assert len(input_mask) == max_seq_length
      assert len(segment_ids) == max_seq_length

      span_is_impossible = example.is_impossible
      start_position = None
      end_position = None
      if is_training and not span_is_impossible:
        # For training, if our document chunk does not contain an annotation
        # we throw it out, since there is nothing to predict.
        doc_start = doc_span.start
        doc_end = doc_span.start + doc_span.length - 1
        out_of_span = False
        if not (tok_start_position >= doc_start and
                tok_end_position <= doc_end):
          out_of_span = True
        if out_of_span:
          # continue
          start_position = 0
          end_position = 0
          span_is_impossible = True
        else:
          doc_offset = len(query_tokens) + 2
          start_position = tok_start_position - doc_start + doc_offset
          end_position = tok_end_position - doc_start + doc_offset

      if is_training and span_is_impossible:
        start_position = 0
        end_position = 0

      if example_index < 20:
        logging.info("*** Example ***")
        logging.info("unique_id: %s", (unique_id))
        logging.info("example_index: %s", (example_index))
        logging.info("doc_span_index: %s", (doc_span_index))
        logging.info("tok_start_to_orig_index: %s",
                     " ".join([str(x) for x in cur_tok_start_to_orig_index]))
        logging.info("tok_end_to_orig_index: %s",
                     " ".join([str(x) for x in cur_tok_end_to_orig_index]))
        logging.info(
            "token_is_max_context: %s", " ".join(
                ["%d:%s" % (x, y) for (x, y) in token_is_max_context.items()]))
        logging.info(
            "input_pieces: %s",
            " ".join([tokenizer.sp_model.IdToPiece(x) for x in tokens]))
        logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
        logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
        logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))

        if is_training and span_is_impossible:
          logging.info("impossible example span")

        if is_training and not span_is_impossible:
          pieces = [
              tokenizer.sp_model.IdToPiece(token)
              for token in tokens[start_position:(end_position + 1)]
          ]
          answer_text = tokenizer.sp_model.DecodePieces(pieces)
          logging.info("start_position: %d", (start_position))
          logging.info("end_position: %d", (end_position))
          logging.info("answer: %s", (tokenization.printable_text(answer_text)))

      # With multi processing, the example_index is actually the index
      # within the current process therefore we use example_index=None
      # to avoid being used in the future.
      # The current code does not use example_index of training data.
      if is_training:
        feat_example_index = None
      else:
        feat_example_index = example_index

      feature = InputFeatures(
          unique_id=unique_id,
          example_index=feat_example_index,
          doc_span_index=doc_span_index,
          tok_start_to_orig_index=cur_tok_start_to_orig_index,
          tok_end_to_orig_index=cur_tok_end_to_orig_index,
          token_is_max_context=token_is_max_context,
          tokens=[tokenizer.sp_model.IdToPiece(x) for x in tokens],
          input_ids=input_ids,
          input_mask=input_mask,
          segment_ids=segment_ids,
          paragraph_len=paragraph_len,
          start_position=start_position,
          end_position=end_position,
          is_impossible=span_is_impossible)

      # Run callback
      if is_training:
        output_fn(feature)
      else:
        output_fn(feature, is_padding=False)

      unique_id += 1
      if span_is_impossible:
        cnt_neg += 1
      else:
        cnt_pos += 1

  if not is_training and feature:
    assert batch_size
    num_padding = 0
    num_examples = unique_id - base_id
    if unique_id % batch_size != 0:
      num_padding = batch_size - (num_examples % batch_size)
    dummy_feature = copy.deepcopy(feature)
    for _ in range(num_padding):
      dummy_feature.unique_id = unique_id

      # Run callback
      output_fn(feature, is_padding=True)
      unique_id += 1

  logging.info("Total number of instances: %d = pos %d neg %d",
               cnt_pos + cnt_neg, cnt_pos, cnt_neg)
  return unique_id - base_id
```
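The doc-span construction near the end of `convert_examples_to_features` is the classic BERT sliding window: overlapping chunks of at most `max_tokens_for_doc` pieces, advanced by `doc_stride`. Isolated as a stand-alone sketch with arbitrary parameter values:

```python
import collections

_DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def make_doc_spans(num_tokens, max_tokens_for_doc, doc_stride):
  """Same windowing logic as above, over a document of num_tokens pieces."""
  doc_spans = []
  start_offset = 0
  while start_offset < num_tokens:
    length = min(num_tokens - start_offset, max_tokens_for_doc)
    doc_spans.append(_DocSpan(start=start_offset, length=length))
    if start_offset + length == num_tokens:
      break
    start_offset += min(length, doc_stride)
  return doc_spans

print(make_doc_spans(10, 6, 3))
# [DocSpan(start=0, length=6), DocSpan(start=3, length=6),
#  DocSpan(start=6, length=4)]
```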
```python
def _check_is_max_context(doc_spans, cur_span_index, position):
  """Check if this is the 'max context' doc span for the token."""

  # Because of the sliding window approach taken to scoring documents, a single
  # token can appear in multiple documents. E.g.
  #  Doc: the man went to the store and bought a gallon of milk
  #  Span A: the man went to the
  #  Span B: to the store and bought
  #  Span C: and bought a gallon of
  #  ...
  #
  # Now the word 'bought' will have two scores from spans B and C. We only
  # want to consider the score with "maximum context", which we define as
  # the *minimum* of its left and right context (the *sum* of left and
  # right context will always be the same, of course).
  #
  # In the example the maximum context for 'bought' would be span C since
  # it has 1 left context and 3 right context, while span B has 4 left context
  # and 0 right context.
  best_score = None
  best_span_index = None
  for (span_index, doc_span) in enumerate(doc_spans):
    end = doc_span.start + doc_span.length - 1
    if position < doc_span.start:
      continue
    if position > end:
      continue
    num_left_context = position - doc_span.start
    num_right_context = end - position
    score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
    if best_score is None or score > best_score:
      best_score = score
      best_span_index = span_index

  return cur_span_index == best_span_index


RawResult = collections.namedtuple("RawResult",
                                   ["unique_id", "start_logits", "end_logits"])
```
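The 'bought' example from the comment in `_check_is_max_context`, made concrete: token 7 is 'bought' in the 12-token document, and it appears in spans B and C. This runs against the function above:

```python
import collections

_DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
spans = [
    _DocSpan(start=0, length=5),  # Span A: "the man went to the"
    _DocSpan(start=3, length=5),  # Span B: "to the store and bought"
    _DocSpan(start=6, length=5),  # Span C: "and bought a gallon of"
]
# Span B score: min(4 left, 0 right) + 0.01 * 5 = 0.05
# Span C score: min(1 left, 3 right) + 0.01 * 5 = 1.05 -> C has max context
print(_check_is_max_context(spans, 1, 7))  # False
print(_check_is_max_context(spans, 2, 7))  # True
```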
```python
def write_predictions(all_examples,
                      all_features,
                      all_results,
                      n_best_size,
                      max_answer_length,
                      do_lower_case,
                      output_prediction_file,
                      output_nbest_file,
                      output_null_log_odds_file,
                      version_2_with_negative=False,
                      null_score_diff_threshold=0.0,
                      verbose=False):
  """Write final predictions to the json file and log-odds of null if needed."""
  del do_lower_case, verbose
  logging.info("Writing predictions to: %s", (output_prediction_file))
  logging.info("Writing nbest to: %s", (output_nbest_file))

  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)

  unique_id_to_result = {}
  for result in all_results:
    unique_id_to_result[result.unique_id] = result

  _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
      "PrelimPrediction",
      ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])

  all_predictions = collections.OrderedDict()
  all_nbest_json = collections.OrderedDict()
  scores_diff_json = collections.OrderedDict()

  for (example_index, example) in enumerate(all_examples):
    features = example_index_to_features[example_index]

    prelim_predictions = []
    # keep track of the minimum score of null start+end of position 0
    score_null = 1000000  # large and positive
    min_null_feature_index = 0  # the paragraph slice with min null score
    null_start_logit = 0  # the start logit at the slice with min null score
    null_end_logit = 0  # the end logit at the slice with min null score
    for (feature_index, feature) in enumerate(features):
      result = unique_id_to_result[feature.unique_id]

      start_indexes = _get_best_indexes(result.start_logits, n_best_size)
      end_indexes = _get_best_indexes(result.end_logits, n_best_size)

      # if we could have irrelevant answers, get the min score of irrelevant
      if version_2_with_negative:
        feature_null_score = result.start_logits[0] + result.end_logits[0]
        if feature_null_score < score_null:
          score_null = feature_null_score
          min_null_feature_index = feature_index
          null_start_logit = result.start_logits[0]
          null_end_logit = result.end_logits[0]
      for start_index in start_indexes:
        for end_index in end_indexes:
          doc_offset = feature.tokens.index("[SEP]") + 1
          # We could hypothetically create invalid predictions, e.g., predict
          # that the start of the span is in the question. We throw out all
          # invalid predictions.
          if start_index - doc_offset >= len(feature.tok_start_to_orig_index):
            continue
          if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
            continue
          # if start_index not in feature.tok_start_to_orig_index:
          #   continue
          # if end_index not in feature.tok_end_to_orig_index:
          #   continue
          if not feature.token_is_max_context.get(start_index, False):
            continue
          if end_index < start_index:
            continue
          length = end_index - start_index + 1
          if length > max_answer_length:
            continue
          prelim_predictions.append(
              _PrelimPrediction(
                  feature_index=feature_index,
                  start_index=start_index - doc_offset,
                  end_index=end_index - doc_offset,
                  start_logit=result.start_logits[start_index],
                  end_logit=result.end_logits[end_index]))

    if version_2_with_negative:
      prelim_predictions.append(
          _PrelimPrediction(
              feature_index=min_null_feature_index,
              start_index=-1,
              end_index=-1,
              start_logit=null_start_logit,
              end_logit=null_end_logit))
    prelim_predictions = sorted(
        prelim_predictions,
        key=lambda x: (x.start_logit + x.end_logit),
        reverse=True)

    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_logit", "end_logit"])

    seen_predictions = {}
    nbest = []
    for pred in prelim_predictions:
      if len(nbest) >= n_best_size:
        break
      feature = features[pred.feature_index]
      if pred.start_index >= 0:  # this is a non-null prediction
        tok_start_to_orig_index = feature.tok_start_to_orig_index
        tok_end_to_orig_index = feature.tok_end_to_orig_index
        start_orig_pos = tok_start_to_orig_index[pred.start_index]
        end_orig_pos = tok_end_to_orig_index[pred.end_index]

        paragraph_text = example.paragraph_text
        final_text = paragraph_text[start_orig_pos:end_orig_pos + 1].strip()
        if final_text in seen_predictions:
          continue

        seen_predictions[final_text] = True
      else:
        final_text = ""
        seen_predictions[final_text] = True

      nbest.append(
          _NbestPrediction(
              text=final_text,
              start_logit=pred.start_logit,
              end_logit=pred.end_logit))

    # if we didn't include the empty option in the n-best, include it
    if version_2_with_negative:
      if "" not in seen_predictions:
        nbest.append(
            _NbestPrediction(
                text="", start_logit=null_start_logit,
                end_logit=null_end_logit))

    # In very rare edge cases we could have no valid predictions. So we
    # just create a nonce prediction in this case to avoid failure.
    if not nbest:
      nbest.append(
          _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

    assert len(nbest) >= 1

    total_scores = []
    best_non_null_entry = None
    for entry in nbest:
      total_scores.append(entry.start_logit + entry.end_logit)
      if not best_non_null_entry:
        if entry.text:
          best_non_null_entry = entry

    probs = _compute_softmax(total_scores)

    nbest_json = []
    for (i, entry) in enumerate(nbest):
      output = collections.OrderedDict()
      output["text"] = entry.text
      output["probability"] = probs[i]
      output["start_logit"] = entry.start_logit
      output["end_logit"] = entry.end_logit
      nbest_json.append(output)

    assert len(nbest_json) >= 1

    if not version_2_with_negative:
      all_predictions[example.qas_id] = nbest_json[0]["text"]
    else:
      assert best_non_null_entry is not None
      # predict "" iff the null score - the score of best non-null > threshold
      score_diff = score_null - best_non_null_entry.start_logit - (
          best_non_null_entry.end_logit)
      scores_diff_json[example.qas_id] = score_diff
      if score_diff > null_score_diff_threshold:
        all_predictions[example.qas_id] = ""
      else:
        all_predictions[example.qas_id] = best_non_null_entry.text

    all_nbest_json[example.qas_id] = nbest_json

  with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
    writer.write(json.dumps(all_predictions, indent=4) + "\n")

  with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
    writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

  if version_2_with_negative:
    with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
      writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
```
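The SQuAD 2.0 no-answer decision in `write_predictions` reduces to a single comparison. A worked instance with made-up logits:

```python
# predict "" iff score_null - (best non-null start+end logits) > threshold.
score_null = 2.0                     # start_logits[0] + end_logits[0]
best_start_logit, best_end_logit = 1.5, 1.0
null_score_diff_threshold = 0.0
score_diff = score_null - best_start_logit - best_end_logit  # -0.5
print("" if score_diff > null_score_diff_threshold else "<span answer>")
# -0.5 <= 0.0, so the span answer is kept.
```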
```python
def _get_best_indexes(logits, n_best_size):
  """Get the n-best logits from a list."""
  index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

  best_indexes = []
  for i in range(len(index_and_score)):
    if i >= n_best_size:
      break
    best_indexes.append(index_and_score[i][0])
  return best_indexes


def _compute_softmax(scores):
  """Compute softmax probability over raw logits."""
  if not scores:
    return []

  max_score = None
  for score in scores:
    if max_score is None or score > max_score:
      max_score = score

  exp_scores = []
  total_sum = 0.0
  for score in scores:
    x = math.exp(score - max_score)
    exp_scores.append(x)
    total_sum += x

  probs = []
  for score in exp_scores:
    probs.append(score / total_sum)
  return probs
```
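`_compute_softmax` subtracts the running maximum before exponentiating, which leaves the probabilities unchanged mathematically but prevents `math.exp` from overflowing on large logits. A quick check (runs against the function above):

```python
probs = _compute_softmax([1.0, 2.0, 3.0])
print([round(p, 4) for p in probs])  # [0.09, 0.2447, 0.6652]
# Shifting every score by a constant yields the same distribution, and
# without the max subtraction math.exp(1003.0) would raise OverflowError:
print([round(p, 4) for p in _compute_softmax([1001.0, 1002.0, 1003.0])])
```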
```python
class FeatureWriter(object):
  """Writes InputFeature to TF example file."""

  def __init__(self, filename, is_training):
    self.filename = filename
    self.is_training = is_training
    self.num_features = 0
    self._writer = tf.io.TFRecordWriter(filename)

  def process_feature(self, feature):
    """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
    self.num_features += 1

    def create_int_feature(values):
      feature = tf.train.Feature(
          int64_list=tf.train.Int64List(value=list(values)))
      return feature

    features = collections.OrderedDict()
    features["unique_ids"] = create_int_feature([feature.unique_id])
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    if self.is_training:
      features["start_positions"] = create_int_feature(
          [feature.start_position])
      features["end_positions"] = create_int_feature([feature.end_position])
      impossible = 0
      if feature.is_impossible:
        impossible = 1
      features["is_impossible"] = create_int_feature([impossible])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    self._writer.write(tf_example.SerializeToString())

  def close(self):
    self._writer.close()
```
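A hedged sketch of reading back one record that `FeatureWriter.process_feature` serialized in training mode; the path is the placeholder used earlier, and a real record file must exist for this to run:

```python
import tensorflow as tf

# Placeholder path from the earlier sketch; requires an actual record file.
raw = next(iter(tf.data.TFRecordDataset("/tmp/albert/squad_train.tf_record")))
example = tf.train.Example.FromString(raw.numpy())
print(sorted(example.features.feature.keys()))
# Expected keys for a training record, per process_feature above:
# ['end_positions', 'input_ids', 'input_mask', 'is_impossible',
#  'segment_ids', 'start_positions', 'unique_ids']
```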
```python
def generate_tf_record_from_json_file(input_file_path,
                                      sp_model_file,
                                      output_path,
                                      max_seq_length=384,
                                      do_lower_case=True,
                                      max_query_length=64,
                                      doc_stride=128,
                                      version_2_with_negative=False):
  """Generates and saves training data into a tf record file."""
  train_examples = read_squad_examples(
      input_file=input_file_path,
      is_training=True,
      version_2_with_negative=version_2_with_negative)
  tokenizer = tokenization.FullSentencePieceTokenizer(
      sp_model_file=sp_model_file)
  train_writer = FeatureWriter(filename=output_path, is_training=True)
  number_of_examples = convert_examples_to_features(
      examples=train_examples,
      tokenizer=tokenizer,
      max_seq_length=max_seq_length,
      doc_stride=doc_stride,
      max_query_length=max_query_length,
      is_training=True,
      output_fn=train_writer.process_feature,
      do_lower_case=do_lower_case)
  train_writer.close()

  meta_data = {
      "task_type": "bert_squad",
      "train_data_size": number_of_examples,
      "max_seq_length": max_seq_length,
      "max_query_length": max_query_length,
      "doc_stride": doc_stride,
      "version_2_with_negative": version_2_with_negative,
  }

  return meta_data
```
official/nlp/bert/tokenization.py

```diff
@@ -32,7 +32,7 @@ import tensorflow as tf
 import sentencepiece as spm

-SPIECE_UNDERLINE = u"▁".encode("utf-8")
+SPIECE_UNDERLINE = "▁"


 def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
@@ -458,6 +458,9 @@ def encode_pieces(sp_model, text, sample=False):
   Returns:
     A list of token pieces.
   """
+  if six.PY2 and isinstance(text, six.text_type):
+    text = six.ensure_binary(text, "utf-8")
+
   if not sample:
     pieces = sp_model.EncodeAsPieces(text)
   else:
@@ -466,8 +469,8 @@ def encode_pieces(sp_model, text, sample=False):
   for piece in pieces:
     piece = printable_text(piece)
     if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
-      cur_pieces = sp_model.EncodeAsPieces(
-          six.ensure_binary(piece[:-1]).replace(SPIECE_UNDERLINE, b""))
+      cur_pieces = sp_model.EncodeAsPieces(
+          piece[:-1].replace(SPIECE_UNDERLINE, ""))
       if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
         if len(cur_pieces[0]) == 1:
           cur_pieces = cur_pieces[1:]
@@ -514,21 +517,21 @@ class FullSentencePieceTokenizer(object):
     Args:
       sp_model_file: The path to the sentence piece model file.
     """
-    self._sp_model = spm.SentencePieceProcessor()
-    self._sp_model.Load(sp_model_file)
+    self.sp_model = spm.SentencePieceProcessor()
+    self.sp_model.Load(sp_model_file)
     self.vocab = {
-        self._sp_model.IdToPiece(i): i
-        for i in six.moves.range(self._sp_model.GetPieceSize())
+        self.sp_model.IdToPiece(i): i
+        for i in six.moves.range(self.sp_model.GetPieceSize())
     }

   def tokenize(self, text):
     """Tokenizes text into pieces."""
-    return encode_pieces(self._sp_model, text)
+    return encode_pieces(self.sp_model, text)

   def convert_tokens_to_ids(self, tokens):
     """Converts a list of tokens to a list of ids."""
-    return [self._sp_model.PieceToId(printable_text(token)) for token in tokens]
+    return [self.sp_model.PieceToId(printable_text(token)) for token in tokens]

   def convert_ids_to_tokens(self, ids):
     """Converts a list of ids to a list of tokens."""
-    return [self._sp_model.IdToPiece(id_) for id_ in ids]
+    return [self.sp_model.IdToPiece(id_) for id_ in ids]
```
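Note the rename from `self._sp_model` to `self.sp_model`: the new `squad_lib_sp.py` reaches into the tokenizer directly (for example `tokenizer.sp_model.PieceToId("[CLS]")`), so the SentencePiece processor has to be a public attribute. A usage sketch; the model path is a placeholder and a real SentencePiece model file is required to run it:

```python
from official.nlp.bert import tokenization

tokenizer = tokenization.FullSentencePieceTokenizer(
    sp_model_file="/tmp/albert/30k-clean.model")  # placeholder path
pieces = tokenizer.tokenize("ALBERT uses sentence pieces.")
ids = tokenizer.convert_tokens_to_ids(pieces)
# Direct access that squad_lib_sp relies on after this change:
cls_id = tokenizer.sp_model.PieceToId("[CLS]")
```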