chenpangpang / transformers / Commits

Commit 22e7c4ed
authored Oct 03, 2019 by erenup
parent ebb32261

fixing for roberta tokenizer decoding
Showing 2 changed files with 24 additions and 17 deletions:

examples/run_squad.py    +2  -2
examples/utils_squad.py  +22 -15
examples/run_squad.py

@@ -263,7 +263,7 @@ def evaluate(args, model, tokenizer, prefix=""):
     write_predictions(examples, features, all_results, args.n_best_size,
                     args.max_answer_length, args.do_lower_case, output_prediction_file,
                     output_nbest_file, output_null_log_odds_file, args.verbose_logging,
-                    args.version_2_with_negative, args.null_score_diff_threshold)
+                    args.version_2_with_negative, args.null_score_diff_threshold, tokenizer, args.model_type)

     # Evaluate with the official SQuAD script
     evaluate_options = EVAL_OPTS(data_file=args.predict_file,
@@ -296,7 +296,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
                                             max_seq_length=args.max_seq_length,
                                             doc_stride=args.doc_stride,
                                             max_query_length=args.max_query_length,
-                                            is_training=not evaluate)
+                                            is_training=not evaluate, add_prefix_space=True if args.model_type == 'roberta' else False)
     if args.local_rank in [-1, 0]:
         logger.info("Saving features into cached file %s", cached_features_file)
         torch.save(features, cached_features_file)
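The conditional matters because RoBERTa reuses GPT-2's byte-level BPE, in which a leading space is encoded into the token itself (rendered as "Ġ"), so a word tokenizes differently at the start of a string than after a space. Passing add_prefix_space=True makes a standalone string tokenize as it would mid-sentence. A minimal sketch of the behavior, assuming the slow RobertaTokenizer from the transformers 2.x line this commit targets; the outputs in the comments are indicative:

    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    # Word-initial form: no space is encoded into the first piece.
    tokenizer.tokenize('world')                         # ['world']

    # Mid-sentence form: the leading space becomes part of the token.
    tokenizer.tokenize('world', add_prefix_space=True)  # ['Ġworld']

BERT's WordPiece makes no such distinction, which is why the flag is only enabled when args.model_type == 'roberta'.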
examples/utils_squad.py

@@ -25,6 +25,7 @@ import collections
 from io import open

 from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
+from transformers.tokenization_roberta import RobertaTokenizer

 # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method)
 from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores
@@ -192,7 +193,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                  cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
                                  sequence_a_segment_id=0, sequence_b_segment_id=1,
                                  cls_token_segment_id=0, pad_token_segment_id=0,
-                                 mask_padding_with_zero=True):
+                                 mask_padding_with_zero=True, add_prefix_space=False):
     """Loads a data file into a list of `InputBatch`s."""

     unique_id = 1000000000
@@ -205,8 +206,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         # if example_index % 100 == 0:
         #     logger.info('Converting %s/%s pos %s neg %s', example_index, len(examples), cnt_pos, cnt_neg)

-        query_tokens = tokenizer.tokenize(example.question_text)
+        query_tokens = tokenizer.tokenize(example.question_text, add_prefix_space=add_prefix_space)

         if len(query_tokens) > max_query_length:
             query_tokens = query_tokens[0:max_query_length]
@@ -216,7 +216,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         all_doc_tokens = []
         for (i, token) in enumerate(example.doc_tokens):
             orig_to_tok_index.append(len(all_doc_tokens))
-            sub_tokens = tokenizer.tokenize(token)
+            sub_tokens = tokenizer.tokenize(token, add_prefix_space=add_prefix_space)
             for sub_token in sub_tokens:
                 tok_to_orig_index.append(i)
                 all_doc_tokens.append(sub_token)
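This loop tokenizes the whitespace-split doc_tokens one word at a time, so without the flag every word would get its word-initial form, and the resulting pieces would not line up with how the same passage tokenizes as running text. A hedged sketch, reusing the tokenizer from the earlier snippet; the assertion should hold for simple words because GPT-2's pre-tokenizer splits on spaces, making each word's encoding context-independent:

    words = ['The', 'quick', 'brown', 'fox']

    # Per-word tokenization, as in the loop above:
    per_word = [t for w in words
                  for t in tokenizer.tokenize(w, add_prefix_space=True)]

    # Tokenization of the running text:
    running = tokenizer.tokenize(' '.join(words), add_prefix_space=True)

    assert per_word == running  # pieces line up only with the prefix space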
@@ -234,7 +234,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                 tok_end_position = len(all_doc_tokens) - 1
             (tok_start_position, tok_end_position) = _improve_answer_span(
                 all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
-                example.orig_answer_text)
+                example.orig_answer_text, add_prefix_space)

         # The -3 accounts for [CLS], [SEP] and [SEP]
         max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
@@ -398,7 +398,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
 def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
-                         orig_answer_text):
+                         orig_answer_text, add_prefix_space):
     """Returns tokenized answer spans that better match the annotated answer."""

     # The SQuAD annotations are character based. We first project them to
@@ -423,7 +423,7 @@ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
     # the word "Japanese". Since our WordPiece tokenizer does not split
     # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
     # in SQuAD, but does happen.
-    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
+    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text, add_prefix_space=add_prefix_space))

     for new_start in range(input_start, input_end + 1):
         for new_end in range(input_end, new_start - 1, -1):
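_improve_answer_span works by string equality: it joins candidate slices of all_doc_tokens with spaces and compares them against tok_answer_text. Since the doc pieces now carry "Ġ" markers, the answer must be tokenized with the same add_prefix_space value or no slice can ever compare equal. Illustratively, with hypothetical values and the same tokenizer as above:

    doc_pieces = [t for w in ['the', 'Eiffel', 'Tower']
                    for t in tokenizer.tokenize(w, add_prefix_space=True)]

    # The function compares " ".join(doc_pieces[start:end + 1]) against these:
    ' '.join(tokenizer.tokenize('Eiffel Tower', add_prefix_space=True))  # equals some slice
    ' '.join(tokenizer.tokenize('Eiffel Tower'))                         # first piece lacks 'Ġ' -- never equal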
@@ -477,7 +477,7 @@ RawResult = collections.namedtuple("RawResult",
 def write_predictions(all_examples, all_features, all_results, n_best_size,
                       max_answer_length, do_lower_case, output_prediction_file,
                       output_nbest_file, output_null_log_odds_file, verbose_logging,
-                      version_2_with_negative, null_score_diff_threshold):
+                      version_2_with_negative, null_score_diff_threshold, tokenizer, mode_type='bert'):
     """Write final predictions to the json file and log-odds of null if needed."""

     logger.info("Writing predictions to: %s" % (output_prediction_file))
     logger.info("Writing nbest to: %s" % (output_nbest_file))
@@ -576,6 +576,13 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
             tok_text = " ".join(tok_tokens)

             # De-tokenize WordPieces that have been split off.
-            tok_text = tok_text.replace(" ##", "")
-            tok_text = tok_text.replace("##", "")
+            if mode_type == 'roberta':
+                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+                tok_text = tok_text.replace("##", "")
+                tok_text = " ".join(tok_text.strip().split())
+                orig_text = " ".join(orig_tokens)
+                final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging, None)
+            else:
+                tok_text = tok_text.replace(" ##", "")
+                tok_text = tok_text.replace("##", "")
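This is the decoding fix the commit title refers to. The old cleanup assumed WordPiece, where continuation pieces are marked with a leading "##"; byte-level BPE instead marks word starts with "Ġ", so stripping "##" from space-joined RoBERTa tokens leaves the markers in the text. convert_tokens_to_string reverses the byte-level encoding properly. (Note the new parameter is spelled mode_type rather than model_type; run_squad.py passes args.model_type positionally, so the spelling stays internal to write_predictions.) A rough before/after sketch, again assuming a 2.x RobertaTokenizer; token values are indicative:

    toks = tokenizer.tokenize('a neural network', add_prefix_space=True)
    # toks is something like ['Ġa', 'Ġneural', 'Ġnetwork']

    # WordPiece-style cleanup leaves the byte-level space markers behind:
    ' '.join(toks).replace(' ##', '').replace('##', '')   # 'Ġa Ġneural Ġnetwork'

    # The tokenizer's own decoder recovers the original string:
    tokenizer.convert_tokens_to_string(toks).strip()      # 'a neural network'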