chenpangpang / transformers · Commits

Commit c3ba6452, authored Nov 22, 2019 by Lysandre, committed by LysandreJik on Nov 22, 2019
Works for XLNet
parent a5a8a617
Showing 2 changed files with 50 additions and 72 deletions (+50, -72)
examples/run_squad.py (+10, -28)
transformers/data/processors/squad.py (+40, -44)
examples/run_squad.py (view file @ c3ba6452)
@@ -16,6 +16,7 @@
 """ Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
 from __future__ import absolute_import, division, print_function

+from transformers.data.processors.squad import SquadV1Processor
 import argparse
 import logging
@@ -46,8 +47,7 @@ from transformers import (WEIGHTS_NAME, BertConfig,
 from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features, read_squad_examples as sread_squad_examples

-from utils_squad import (read_squad_examples, convert_examples_to_features,
-                         RawResult, write_predictions,
+from utils_squad import (RawResult, write_predictions,
                          RawResultExtended, write_predictions_extended)

 # The follwing import is the official SQuAD evaluation script (2.0).
@@ -289,7 +289,6 @@ def evaluate(args, model, tokenizer, prefix=""):
     results = evaluate_on_squad(evaluate_options)
     return results

-
 def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     if args.local_rank not in [-1, 0] and not evaluate:
         torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
@@ -308,24 +307,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
         examples = read_squad_examples(input_file=input_file,
                                        is_training=not evaluate,
                                        version_2_with_negative=args.version_2_with_negative)
-        examples = examples[:10]
-        features = convert_examples_to_features(examples=examples,
-                                                tokenizer=tokenizer,
-                                                max_seq_length=args.max_seq_length,
-                                                doc_stride=args.doc_stride,
-                                                max_query_length=args.max_query_length,
-                                                is_training=not evaluate,
-                                                cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
-                                                pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
-                                                cls_token_at_end=True if args.model_type in ['xlnet'] else False,
-                                                sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
-
-        exampless = sread_squad_examples(input_file=input_file,
-                                         is_training=not evaluate,
-                                         version_2_with_negative=args.version_2_with_negative)
-        exampless = exampless[:10]
-        features2 = squad_convert_examples_to_features(examples=exampless,
+        keep_n_examples = 1000
+        processor = SquadV1Processor()
+        values = processor.get_dev_examples("examples/squad")
+        examples = values[:keep_n_examples]
+        features = squad_convert_examples_to_features(examples=exampless,
                                                       tokenizer=tokenizer,
                                                       max_seq_length=args.max_seq_length,
                                                       doc_stride=args.doc_stride,
@@ -335,15 +321,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
                                                       pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
                                                       cls_token_at_end=True if args.model_type in ['xlnet'] else False,
                                                       sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
-        print(features2)
-        for i in range(len(features)):
-            assert features[i] == features2[i]
-            print("Equal")
         print("DONE")
+        import sys
+        sys.exit()

         if args.local_rank in [-1, 0]:
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)
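For orientation, here is a minimal sketch of the processor-based loading path that this commit switches run_squad.py toward. It only uses calls that appear in the diff above (SquadV1Processor.get_dev_examples and squad_convert_examples_to_features); the tokenizer choice, the data directory and the argument values are illustrative placeholders, and keyword arguments not visible in the shown hunks are omitted.

# Sketch only, not part of the commit: exercise the new SQuAD processor path at this revision.
from transformers import XLNetTokenizer, squad_convert_examples_to_features
from transformers.data.processors.squad import SquadV1Processor

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')   # placeholder model choice
processor = SquadV1Processor()
examples = processor.get_dev_examples("examples/squad")          # directory holding dev-v1.1.json (assumed layout)

features = squad_convert_examples_to_features(examples=examples,
                                              tokenizer=tokenizer,
                                              max_seq_length=384,          # placeholder values
                                              doc_stride=128,
                                              sequence_a_is_doc=True)      # True for XLNet at this revision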
transformers/data/processors/squad.py (view file @ c3ba6452)
@@ -83,6 +83,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                                        sequence_a_is_doc=False):
     """Loads a data file into a list of `InputBatch`s."""

+    cls_token = tokenizer.cls_token
+    sep_token = tokenizer.sep_token
+
     # Defining helper methods
     unique_id = 1000000000
@@ -136,24 +139,24 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair

         encoded_dict = tokenizer.encode_plus(
-            truncated_query,
-            all_doc_tokens,
+            truncated_query if not sequence_a_is_doc else all_doc_tokens,
+            all_doc_tokens if not sequence_a_is_doc else truncated_query,
             max_length=max_seq_length,
             padding_strategy='right',
             stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
             return_overflowing_tokens=True,
-            truncation_strategy='only_second'
+            truncation_strategy='only_second' if not sequence_a_is_doc else 'only_first'
         )

         ids = encoded_dict['input_ids']
-        print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
         non_padded_ids = ids[:ids.index(tokenizer.pad_token_id)] if tokenizer.pad_token_id in ids else ids
         paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
         tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
         token_to_orig_map = {}
         for i in range(paragraph_len):
-            token_to_orig_map[len(truncated_query) + sequence_added_tokens + i] = tok_to_orig_index[0 + i]
+            index = len(truncated_query) + sequence_added_tokens + i if not sequence_a_is_doc else i
+            token_to_orig_map[index] = tok_to_orig_index[0 + i]

         encoded_dict["paragraph_len"] = paragraph_len
         encoded_dict["tokens"] = tokens
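The encode_plus call above asks the tokenizer to do the document windowing itself: return_overflowing_tokens=True hands back the part of the document that did not fit, and the stride value is chosen so that consecutive windows start doc_stride tokens apart (the while loop in the next hunk keeps re-encoding the overflow until the whole document is covered). The snippet below is a simplified, standalone illustration of that windowing on a plain token list; it is not the tokenizer's implementation, and the helper name is invented.

# Simplified sketch of doc-stride windowing (illustration only).
def make_doc_spans(doc_tokens, max_doc_len, doc_stride):
    spans = []
    start = 0
    while start < len(doc_tokens):
        length = min(max_doc_len, len(doc_tokens) - start)   # fill the window as far as possible
        spans.append(doc_tokens[start:start + length])
        if start + length == len(doc_tokens):                # the document is fully covered
            break
        start += doc_stride                                  # next window starts doc_stride tokens later
    return spans

# 10 document tokens, window of 5, stride of 3 -> overlapping spans [0:5], [3:8], [6:10]
print(make_doc_spans(list(range(10)), max_doc_len=5, doc_stride=3))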
@@ -164,35 +167,40 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         encoded_dict["length"] = paragraph_len
         spans.append(encoded_dict)

-        print("YESSIR", len(spans) * doc_stride < len(all_doc_tokens), "overflowing_tokens" in encoded_dict)
+        # print("YESSIR", len(spans) * doc_stride < len(all_doc_tokens), "overflowing_tokens" in encoded_dict)
         while len(spans) * doc_stride < len(all_doc_tokens) and "overflowing_tokens" in encoded_dict:
-            overflowing_tokens = encoded_dict['overflowing_tokens']
-            print("OVERFLOW", len(overflowing_tokens))
+            overflowing_tokens = encoded_dict["overflowing_tokens"]
             encoded_dict = tokenizer.encode_plus(
-                truncated_query,
-                overflowing_tokens,
+                truncated_query if not sequence_a_is_doc else overflowing_tokens,
+                overflowing_tokens if not sequence_a_is_doc else truncated_query,
                 max_length=max_seq_length,
                 return_overflowing_tokens=True,
                 padding_strategy='right',
                 stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
-                truncation_strategy='only_second'
+                truncation_strategy='only_second' if not sequence_a_is_doc else 'only_first'
            )

            ids = encoded_dict['input_ids']
-           print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
+           # print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
+           # print(encoded_dict["input_ids"].index(tokenizer.pad_token_id) if tokenizer.pad_token_id in encoded_dict["input_ids"] else None)
+           # print(len(spans) * doc_stride, len(all_doc_tokens))

            # Length of the document without the query
            paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)

-           non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
+           if tokenizer.pad_token_id in encoded_dict['input_ids']:
+               non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
+           else:
+               non_padded_ids = encoded_dict['input_ids']

            tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
            token_to_orig_map = {}
            for i in range(paragraph_len):
-               token_to_orig_map[len(truncated_query) + sequence_added_tokens + i] = tok_to_orig_index[len(spans) * doc_stride + i]
+               index = len(truncated_query) + sequence_added_tokens + i if not sequence_a_is_doc else i
+               token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]

            encoded_dict["paragraph_len"] = paragraph_len
            encoded_dict["tokens"] = tokens
@@ -202,23 +210,14 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            encoded_dict["start"] = len(spans) * doc_stride
            encoded_dict["length"] = paragraph_len

-           # split_token_index = doc_span.start + i
-           # token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
-           # is_max_context = _check_is_max_context(doc_spans, doc_span_index,
-           #                                        split_token_index)
-           # token_is_max_context[len(tokens)] = is_max_context
-           # tokens.append(all_doc_tokens[split_token_index])
            spans.append(encoded_dict)

        for doc_span_index in range(len(spans)):
            for j in range(spans[doc_span_index]["paragraph_len"]):
                is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
-               index = spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
+               index = j if sequence_a_is_doc else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
                spans[doc_span_index]["token_is_max_context"][index] = is_max_context

-       print("new span", len(spans))
        for span in spans:
            # Identify the position of the CLS token
            cls_index = span['input_ids'].index(tokenizer.cls_token_id)
@@ -227,17 +226,17 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            # Original TF implem also keep the classification token (set to 0) (not sure why...)
            p_mask = np.array(span['token_type_ids'])

-           # Convert all SEP indices to '0' before inversion
-           p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 0
+           p_mask = np.minimum(p_mask, 1)

-           # Limit positive values to one
-           p_mask = 1 - np.minimum(p_mask, 1)
+           if not sequence_a_is_doc:
+               # Limit positive values to one
+               p_mask = 1 - p_mask
+
+           p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1

            # Set the CLS index to '0'
            p_mask[cls_index] = 0

-           print("new features length", len(new_features))
            new_features.append(NewSquadFeatures(
                span['input_ids'],
                span['attention_mask'],
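The p_mask built above marks with a 1 every position that cannot contain the answer (question tokens and separators), leaves 0 for document tokens, and keeps the CLS position at 0, as the original TF implementation does. A small worked example of the sequence_a_is_doc=False branch, with made-up token ids:

# Worked example (illustration only, made-up ids) of the p_mask logic above.
import numpy as np

token_type_ids = [0, 0, 0, 0, 1, 1, 1, 1]            # [CLS] q q [SEP] d d d [SEP]
input_ids      = [101, 7, 8, 102, 20, 21, 22, 102]   # pretend 101 = [CLS], 102 = [SEP]
sep_token_id, cls_index = 102, 0

p_mask = np.array(token_type_ids)
p_mask = np.minimum(p_mask, 1)
p_mask = 1 - p_mask                                              # invert: document tokens become 0
p_mask[np.where(np.array(input_ids) == sep_token_id)[0]] = 1     # [SEP] can never be part of the answer
p_mask[cls_index] = 0                                            # keep [CLS] available
print(p_mask)                                                    # -> [0 1 1 1 0 0 0 1]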
@@ -287,19 +286,15 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
-           print("OLD DOC CREATION BEGIN", start_offset, len(all_doc_tokens))
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
+           # print("Start offset is", start_offset, len(all_doc_tokens), "length is", length)
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
-               print("Done with this doc span, breaking out.", start_offset, length)
                break
-           print("CHOOSING OFFSET", length, doc_stride)
            start_offset += min(length, doc_stride)
-       print("OLD DOC CREATION END", start_offset)

-       print("old span", len(doc_spans))
        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
@@ -382,7 +377,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                input_mask.append(0 if mask_padding_with_zero else 1)
                segment_ids.append(pad_token_segment_id)
                p_mask.append(1)
-           print("[OLD] Ids computed; position of the first padding", input_ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in input_ids else None)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
@@ -440,7 +435,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            # logger.info(
            #     "answer: %s" % (answer_text))
-           print("features length", len(features))
            features.append(
                SquadFeatures(
                    unique_id=unique_id,
@@ -464,10 +458,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
    assert len(features) == len(new_features)
    for i in range(len(features)):
-       print(i)
        feature, new_feature = features[i], new_features[i]

-       input_ids = feature.input_ids
+       input_ids = [f if f not in [3, 4, 5] else 0 for f in feature.input_ids]
        input_mask = feature.input_mask
        segment_ids = feature.segment_ids
        cls_index = feature.cls_index
@@ -478,7 +471,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
        tokens = feature.tokens
        token_to_orig_map = feature.token_to_orig_map

-       new_input_ids = new_feature.input_ids
+       new_input_ids = [f if f not in [3, 4, 5] else 0 for f in new_feature.input_ids]
        new_input_mask = new_feature.attention_mask
        new_segment_ids = new_feature.token_type_ids
        new_cls_index = new_feature.cls_index
@@ -497,6 +490,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
        assert example_index == new_example_index
        assert paragraph_len == new_paragraph_len
        assert token_is_max_context == new_token_is_max_context
+
+       tokens = [t if tokenizer.convert_tokens_to_ids(t) is not tokenizer.unk_token_id else tokenizer.unk_token for t in tokens]
+
        assert tokens == new_tokens
        assert token_to_orig_map == new_token_to_orig_map
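The assert block that closes this file compares the legacy feature pipeline against the new encode_plus-based one, and the two list comprehensions above zero out ids 3, 4 and 5 on both sides before comparing; for the XLNet tokenizer these appear to be the special <cls>, <sep> and <pad> ids, i.e. exactly the positions where the two pipelines are expected to disagree. A tiny standalone sketch of that normalization trick (the helper name is invented, not from the commit):

# Hypothetical helper showing the id-normalization trick used in the asserts.
def normalize_ids(input_ids, ignored_ids=(3, 4, 5)):
    # Zero out ids whose placement may legitimately differ between the two pipelines.
    return [0 if i in ignored_ids else i for i in input_ids]

# Two encodings that differ only in those ids compare equal after normalization.
assert normalize_ids([3, 17, 42, 4, 5]) == normalize_ids([4, 17, 42, 3, 5])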