chenpangpang / transformers · Commit c3ba6452
Authored Nov 22, 2019 by Lysandre, committed by LysandreJik (Nov 22, 2019)
Parent: a5a8a617

Works for XLNet

Showing 2 changed files with 50 additions and 72 deletions
examples/run_squad.py                      +10 -28
transformers/data/processors/squad.py      +40 -44
examples/run_squad.py
@@ -16,6 +16,7 @@
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""

from __future__ import absolute_import, division, print_function
+from transformers.data.processors.squad import SquadV1Processor

import argparse
import logging
...
@@ -46,8 +47,7 @@ from transformers import (WEIGHTS_NAME, BertConfig,
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features, read_squad_examples as sread_squad_examples

-from utils_squad import (read_squad_examples, convert_examples_to_features,
-                         RawResult, write_predictions,
+from utils_squad import (RawResult, write_predictions,
                         RawResultExtended, write_predictions_extended)

# The follwing import is the official SQuAD evaluation script (2.0).
...
@@ -289,7 +289,6 @@ def evaluate(args, model, tokenizer, prefix=""):
    results = evaluate_on_squad(evaluate_options)
    return results

def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
...
@@ -308,24 +307,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
-        examples = read_squad_examples(input_file=input_file,
-                                       is_training=not evaluate,
-                                       version_2_with_negative=args.version_2_with_negative)
-        examples = examples[:10]
-        features = convert_examples_to_features(examples=examples,
-                                                tokenizer=tokenizer,
-                                                max_seq_length=args.max_seq_length,
-                                                doc_stride=args.doc_stride,
-                                                max_query_length=args.max_query_length,
-                                                is_training=not evaluate,
-                                                cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
-                                                pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
-                                                cls_token_at_end=True if args.model_type in ['xlnet'] else False,
-                                                sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
-        exampless = sread_squad_examples(input_file=input_file,
-                                         is_training=not evaluate,
-                                         version_2_with_negative=args.version_2_with_negative)
-        exampless = exampless[:10]
-        features2 = squad_convert_examples_to_features(examples=exampless,
+        keep_n_examples = 1000
+        processor = SquadV1Processor()
+        values = processor.get_dev_examples("examples/squad")
+        examples = values[:keep_n_examples]
+        features = squad_convert_examples_to_features(examples=exampless,
                                                       tokenizer=tokenizer,
                                                       max_seq_length=args.max_seq_length,
                                                       doc_stride=args.doc_stride,
...
@@ -335,14 +321,10 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
                                                       pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
                                                       cls_token_at_end=True if args.model_type in ['xlnet'] else False,
                                                       sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)

-        print(features2)
-        for i in range(len(features)):
-            assert features[i] == features2[i]
-            print("Equal")
        print("DONE")
        import sys
        sys.exit()

    if args.local_rank in [-1, 0]:
        logger.info("Saving features into cached file %s", cached_features_file)
...
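For orientation, the new debug path above loads dev examples through SquadV1Processor and converts them with squad_convert_examples_to_features. Below is a minimal sketch of that flow, assuming a BERT tokenizer, standard SQuAD settings (384/128/64) and a placeholder data directory; only the examples/tokenizer/max_seq_length/doc_stride keywords are visible in this hunk, and the remaining keyword names are assumed from the older convert_examples_to_features call shown above.

from transformers import BertTokenizer, squad_convert_examples_to_features
from transformers.data.processors.squad import SquadV1Processor

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Load SQuAD v1.1 dev examples from a local data directory (placeholder path).
processor = SquadV1Processor()
examples = processor.get_dev_examples("examples/squad")[:1000]

# Convert the examples to model features; keyword names partly assumed (see note above).
features = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
)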
transformers/data/processors/squad.py
@@ -83,6 +83,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                                       sequence_a_is_doc=False):
    """Loads a data file into a list of `InputBatch`s."""

+    cls_token = tokenizer.cls_token
+    sep_token = tokenizer.sep_token
+
    # Defining helper methods
    unique_id = 1000000000
...
@@ -136,24 +139,24 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
        sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair

        encoded_dict = tokenizer.encode_plus(
-            truncated_query,
-            all_doc_tokens,
+            truncated_query if not sequence_a_is_doc else all_doc_tokens,
+            all_doc_tokens if not sequence_a_is_doc else truncated_query,
            max_length=max_seq_length,
            padding_strategy='right',
            stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
            return_overflowing_tokens=True,
-            truncation_strategy='only_second'
+            truncation_strategy='only_second' if not sequence_a_is_doc else 'only_first'
        )

        ids = encoded_dict['input_ids']
        print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
        non_padded_ids = ids[:ids.index(tokenizer.pad_token_id)] if tokenizer.pad_token_id in ids else ids
        paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
        tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
        token_to_orig_map = {}
        for i in range(paragraph_len):
-            token_to_orig_map[len(truncated_query) + sequence_added_tokens + i] = tok_to_orig_index[0 + i]
+            index = len(truncated_query) + sequence_added_tokens + i if not sequence_a_is_doc else i
+            token_to_orig_map[index] = tok_to_orig_index[0 + i]

        encoded_dict["paragraph_len"] = paragraph_len
        encoded_dict["tokens"] = tokens
...
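The conditional arguments above are the core of the XLNet change: when sequence_a_is_doc is set, the document is passed as the first sequence and the query as the second, and the truncation strategy flips so that it is always the document side that gets windowed. A toy restatement of that selection logic follows (pick_encode_plus_pair is an illustrative helper, not the library call itself):

def pick_encode_plus_pair(truncated_query, doc_tokens, sequence_a_is_doc):
    # The document must always be the sequence that gets truncated/windowed,
    # whichever slot it occupies in the pair.
    if sequence_a_is_doc:
        # XLNet-style layout: document first, query second.
        return doc_tokens, truncated_query, "only_first"
    # BERT-style layout: query first, document second.
    return truncated_query, doc_tokens, "only_second"


query = ["what", "is", "squad", "?"]
doc = ["squad", "is", "a", "reading", "comprehension", "dataset", "."]
print(pick_encode_plus_pair(query, doc, sequence_a_is_doc=False)[2])  # only_second
print(pick_encode_plus_pair(query, doc, sequence_a_is_doc=True)[2])   # only_first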
@@ -164,35 +167,40 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
        encoded_dict["length"] = paragraph_len

        spans.append(encoded_dict)

-        print("YESSIR", len(spans) * doc_stride < len(all_doc_tokens), "overflowing_tokens" in encoded_dict)
-        while len(spans) * doc_stride < len(all_doc_tokens) and "overflowing_tokens" in encoded_dict:
-            overflowing_tokens = encoded_dict['overflowing_tokens']
-            print("OVERFLOW", len(overflowing_tokens))
+        # print("YESSIR", len(spans) * doc_stride < len(all_doc_tokens), "overflowing_tokens" in encoded_dict)
+        while len(spans) * doc_stride < len(all_doc_tokens) and "overflowing_tokens" in encoded_dict:
+            overflowing_tokens = encoded_dict["overflowing_tokens"]
            encoded_dict = tokenizer.encode_plus(
-                truncated_query,
-                overflowing_tokens,
+                truncated_query if not sequence_a_is_doc else overflowing_tokens,
+                overflowing_tokens if not sequence_a_is_doc else truncated_query,
                max_length=max_seq_length,
                return_overflowing_tokens=True,
                padding_strategy='right',
                stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
-                truncation_strategy='only_second'
+                truncation_strategy='only_second' if not sequence_a_is_doc else 'only_first'
            )

-            ids = encoded_dict['input_ids']
-            print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
+            # print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
            # print(encoded_dict["input_ids"].index(tokenizer.pad_token_id) if tokenizer.pad_token_id in encoded_dict["input_ids"] else None)
            # print(len(spans) * doc_stride, len(all_doc_tokens))

            # Length of the document without the query
            paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
-            non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
+            if tokenizer.pad_token_id in encoded_dict['input_ids']:
+                non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
+            else:
+                non_padded_ids = encoded_dict['input_ids']

            tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
            token_to_orig_map = {}
            for i in range(paragraph_len):
-                token_to_orig_map[len(truncated_query) + sequence_added_tokens + i] = tok_to_orig_index[len(spans) * doc_stride + i]
+                index = len(truncated_query) + sequence_added_tokens + i if not sequence_a_is_doc else i
+                token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]

            encoded_dict["paragraph_len"] = paragraph_len
            encoded_dict["tokens"] = tokens
...
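The index computation introduced in both encode_plus blocks decides where the i-th document token of the current span sits inside the encoded sequence: shifted past the query and its special tokens in the BERT-style layout, at the very start in the XLNet-style layout. A toy restatement with made-up numbers (doc_token_position is an illustrative helper, not part of the library):

def doc_token_position(i, query_len, sequence_added_tokens, sequence_a_is_doc):
    if sequence_a_is_doc:
        # XLNet-style layout: the document opens the sequence.
        return i
    # BERT-style layout: document tokens come after the truncated query
    # plus the special tokens added for a single sequence.
    return query_len + sequence_added_tokens + i


# 6-token query, 2 special tokens before the document starts:
print(doc_token_position(0, query_len=6, sequence_added_tokens=2, sequence_a_is_doc=False))  # 8
print(doc_token_position(0, query_len=6, sequence_added_tokens=2, sequence_a_is_doc=True))   # 0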
@@ -202,23 +210,14 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            encoded_dict["start"] = len(spans) * doc_stride
            encoded_dict["length"] = paragraph_len

            # split_token_index = doc_span.start + i
            # token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
            # is_max_context = _check_is_max_context(doc_spans, doc_span_index,
            #                                        split_token_index)
            # token_is_max_context[len(tokens)] = is_max_context
            # tokens.append(all_doc_tokens[split_token_index])

            spans.append(encoded_dict)

        for doc_span_index in range(len(spans)):
            for j in range(spans[doc_span_index]["paragraph_len"]):
                is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
-                index = spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
+                index = j if sequence_a_is_doc else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
                spans[doc_span_index]["token_is_max_context"][index] = is_max_context

        print("new span", len(spans))
        for span in spans:
            # Identify the position of the CLS token
            cls_index = span['input_ids'].index(tokenizer.cls_token_id)
...
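The loop above relies on _new_check_is_max_context, whose body is not part of this diff. For orientation only, here is a sketch of the classic BERT-SQuAD scoring that this kind of check is based on: a token is attributed to the overlapping span in which it has the most surrounding context. The helper name and the plain (start, length) tuples are illustrative; the real function works on the span dicts used above.

def check_is_max_context(spans, cur_span_index, position):
    # Score each span that contains the token by its minimum left/right context,
    # with a small bonus for longer spans, and keep the best one.
    best_score, best_span_index = None, None
    for span_index, (start, length) in enumerate(spans):
        end = start + length - 1
        if position < start or position > end:
            continue
        num_left_context = position - start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * length
        if best_score is None or score > best_score:
            best_score, best_span_index = score, span_index
    return best_span_index == cur_span_index


# Two overlapping spans of the document: tokens 0-127 and 64-191.
spans = [(0, 128), (64, 128)]
print(check_is_max_context(spans, 0, 100))  # False: token 100 has more context in the second span
print(check_is_max_context(spans, 1, 100))  # True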
@@ -227,17 +226,17 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            # Original TF implem also keep the classification token (set to 0) (not sure why...)
            p_mask = np.array(span['token_type_ids'])

-            # Convert all SEP indices to '0' before inversion
-            p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 0
+            p_mask = np.minimum(p_mask, 1)
+
+            if not sequence_a_is_doc:
+                # Limit positive values to one
+                p_mask = 1 - p_mask
-            # Limit positive values to one
-            p_mask = 1 - np.minimum(p_mask, 1)
+            p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1

            # Set the CLS index to '0'
            p_mask[cls_index] = 0

            print("new features length", len(new_features))
            new_features.append(NewSquadFeatures(
                span['input_ids'],
                span['attention_mask'],
...
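In the rewritten block above, p_mask marks positions that cannot hold the answer: token_type_ids are clipped to 0/1, inverted only in the BERT-style layout (where the document is segment 1), separators are forced to 1 and the CLS position to 0. A small self-contained restatement with a toy BERT-style encoding follows; build_p_mask and the ids 101/102 standing in for [CLS]/[SEP] are illustrative, not taken from this file.

import numpy as np

def build_p_mask(token_type_ids, input_ids, sep_token_id, cls_index, sequence_a_is_doc):
    # 1 marks tokens that cannot be part of the answer, 0 marks document tokens
    # (and the CLS token, which is used for the "no answer" case).
    p_mask = np.minimum(np.array(token_type_ids), 1)
    if not sequence_a_is_doc:
        # BERT-style layout: the document is segment 1, so invert to mask the query segment.
        p_mask = 1 - p_mask
    # Separators can never be part of the answer.
    p_mask[np.where(np.array(input_ids) == sep_token_id)[0]] = 1
    # The CLS token stays available.
    p_mask[cls_index] = 0
    return p_mask


# Toy BERT-style example: [CLS] q q [SEP] d d d [SEP], segment ids 0/1, SEP id 102.
token_type_ids = [0, 0, 0, 0, 1, 1, 1, 1]
input_ids = [101, 7, 8, 102, 9, 10, 11, 102]
print(build_p_mask(token_type_ids, input_ids, sep_token_id=102, cls_index=0, sequence_a_is_doc=False))
# -> [0 1 1 1 0 0 0 1]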
@@ -287,19 +286,15 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            print("OLD DOC CREATION BEGIN", start_offset, len(all_doc_tokens))
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            # print("Start offset is", start_offset, len(all_doc_tokens), "length is", length)
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                print("Done with this doc span, breaking out.", start_offset, length)
                break
            print("CHOOSING OFFSET", length, doc_stride)
            start_offset += min(length, doc_stride)
            print("OLD DOC CREATION END", start_offset)

        print("old span", len(doc_spans))
        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
...
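The while loop above is the legacy windowing that the new overflow-based spans are being checked against: slide a window of at most max_tokens_for_doc tokens over the document, advancing by doc_stride. A plain restatement (make_doc_spans and the local DocSpan namedtuple are illustrative; the file itself uses a private _DocSpan):

from collections import namedtuple

DocSpan = namedtuple("DocSpan", ["start", "length"])

def make_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
    # Overlapping windows over the document token sequence.
    doc_spans = []
    start_offset = 0
    while start_offset < num_doc_tokens:
        length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
        doc_spans.append(DocSpan(start=start_offset, length=length))
        if start_offset + length == num_doc_tokens:
            break
        start_offset += min(length, doc_stride)
    return doc_spans


# A 300-token document, windows of 128 tokens, stride 64 -> overlapping spans.
print(make_doc_spans(300, 128, 64))
# [DocSpan(start=0, length=128), DocSpan(start=64, length=128),
#  DocSpan(start=128, length=128), DocSpan(start=192, length=108)]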
@@ -382,7 +377,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                input_mask.append(0 if mask_padding_with_zero else 1)
                segment_ids.append(pad_token_segment_id)
                p_mask.append(1)

            print("[OLD] Ids computed; position of the first padding", input_ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in input_ids else None)
            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
...
@@ -440,7 +435,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            #     logger.info(
            #         "answer: %s" % (answer_text))

-            print("features length", len(features))
            features.append(
                SquadFeatures(
                    unique_id=unique_id,
...
@@ -464,10 +458,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
    assert len(features) == len(new_features)
    for i in range(len(features)):
-        print(i)
        feature, new_feature = features[i], new_features[i]

-        input_ids = feature.input_ids
+        input_ids = [f if f not in [3, 4, 5] else 0 for f in feature.input_ids]
        input_mask = feature.input_mask
        segment_ids = feature.segment_ids
        cls_index = feature.cls_index
...
@@ -478,7 +471,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
        tokens = feature.tokens
        token_to_orig_map = feature.token_to_orig_map

-        new_input_ids = new_feature.input_ids
+        new_input_ids = [f if f not in [3, 4, 5] else 0 for f in new_feature.input_ids]
        new_input_mask = new_feature.attention_mask
        new_segment_ids = new_feature.token_type_ids
        new_cls_index = new_feature.cls_index
...
@@ -497,6 +490,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
        assert example_index == new_example_index
        assert paragraph_len == new_paragraph_len
        assert token_is_max_context == new_token_is_max_context

+        tokens = [t if tokenizer.convert_tokens_to_ids(t) is not tokenizer.unk_token_id else tokenizer.unk_token for t in tokens]
+
        assert tokens == new_tokens
        assert token_to_orig_map == new_token_to_orig_map
...
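The comparison harness above normalizes both feature lists before asserting equality: ids 3, 4 and 5 are zeroed out (presumably special-token ids whose placement may legitimately differ between the two pipelines) and tokens that map to the unknown id are folded back to the unk token. A toy restatement of the id normalization (normalize_for_comparison is an illustrative helper, not part of this file):

def normalize_for_comparison(input_ids, special_ids=(3, 4, 5)):
    # Neutralize a handful of special token ids before comparing the old and
    # new pipelines, so only the "real" content tokens are checked.
    return [0 if token_id in special_ids else token_id for token_id in input_ids]


old_ids = [3, 17, 28, 4, 99, 5, 5]
new_ids = [3, 17, 28, 4, 99, 5, 5]
assert normalize_for_comparison(old_ids) == normalize_for_comparison(new_ids)
print(normalize_for_comparison(old_ids))  # [0, 17, 28, 0, 99, 0, 0]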