Commit a5a8a617 authored Nov 21, 2019 by LysandreJik

Works for BERT

parent a7dafe2f

Showing 1 changed file with 433 additions and 76 deletions

transformers/data/processors/squad.py  (+433 / -76)
@@ -3,6 +3,7 @@ import collections
import logging
import os
import json

import numpy as np

from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
from .utils import DataProcessor, InputExample, InputFeatures
@@ -13,19 +14,7 @@ if is_tf_available():

logger = logging.getLogger(__name__)

-def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
-                                        doc_stride, max_query_length, is_training,
-                                        cls_token_at_end=False,
-                                        cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
-                                        sequence_a_segment_id=0, sequence_b_segment_id=1,
-                                        cls_token_segment_id=0, pad_token_segment_id=0,
-                                        mask_padding_with_zero=True,
-                                        sequence_a_is_doc=False):
-    """Loads a data file into a list of `InputBatch`s."""
-
-    # Defining helper methods
-    def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
+def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
                          orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer."""
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
@@ -37,7 +26,8 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            return (new_start, new_end)

    return (input_start, input_end)

-    def _check_is_max_context(doc_spans, cur_span_index, position):
+def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""
    best_score = None
    best_span_index = None
@@ -56,25 +46,221 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
    return cur_span_index == best_span_index


def _new_check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""
    # if len(doc_spans) == 1:
    #     return True
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span["start"] + doc_span["length"] - 1
        if position < doc_span["start"]:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span["start"]
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"]
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index


def _is_whitespace(c):
    if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
        return True
    return False


def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                                       doc_stride, max_query_length, is_training,
                                       cls_token_at_end=True,
                                       cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
                                       sequence_a_segment_id=0, sequence_b_segment_id=1,
                                       cls_token_segment_id=0, pad_token_segment_id=0,
                                       mask_padding_with_zero=True,
                                       sequence_a_is_doc=False):
    """Loads a data file into a list of `InputBatch`s."""

    # Defining helper methods
    unique_id = 1000000000

    features = []
    new_features = []
    for (example_index, example) in enumerate(tqdm(examples)):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        doc_tokens = []
        char_to_word_offset = []
        prev_is_whitespace = True

        # Split on whitespace so that different tokens may be attributed to their original position.
        for c in example.context_text:
            if _is_whitespace(c):
                prev_is_whitespace = True
            else:
                if prev_is_whitespace:
                    doc_tokens.append(c)
                else:
                    doc_tokens[-1] += c
                prev_is_whitespace = False
            char_to_word_offset.append(len(doc_tokens) - 1)

        if is_training:
            # Get start and end position
            answer_length = len(example.answer_text)
            start_position = char_to_word_offset[example.start_position]
            end_position = char_to_word_offset[example.start_position + answer_length - 1]

            # If the answer cannot be found in the text, then skip this example.
            actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
            cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
            if actual_text.find(cleaned_answer_text) == -1:
                logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
                continue

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
-        for (i, token) in enumerate(example.doc_tokens):
+        for (i, token) in enumerate(doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        spans = []

        truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
        sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
        sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair

        encoded_dict = tokenizer.encode_plus(
            truncated_query,
            all_doc_tokens,
            max_length=max_seq_length,
            padding_strategy='right',
            stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
            return_overflowing_tokens=True,
            truncation_strategy='only_second'
        )

        ids = encoded_dict['input_ids']
        print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)
        non_padded_ids = ids[:ids.index(tokenizer.pad_token_id)] if tokenizer.pad_token_id in ids else ids

        paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
        tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)

        token_to_orig_map = {}
        for i in range(paragraph_len):
            token_to_orig_map[len(truncated_query) + sequence_added_tokens + i] = tok_to_orig_index[0 + i]

        encoded_dict["paragraph_len"] = paragraph_len
        encoded_dict["tokens"] = tokens
        encoded_dict["token_to_orig_map"] = token_to_orig_map
        encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
        encoded_dict["token_is_max_context"] = {}
        encoded_dict["start"] = 0
        encoded_dict["length"] = paragraph_len

        spans.append(encoded_dict)

        print("YESSIR", len(spans) * doc_stride < len(all_doc_tokens), "overflowing_tokens" in encoded_dict)
        while len(spans) * doc_stride < len(all_doc_tokens) and "overflowing_tokens" in encoded_dict:
            overflowing_tokens = encoded_dict['overflowing_tokens']
            print("OVERFLOW", len(overflowing_tokens))
            encoded_dict = tokenizer.encode_plus(
                truncated_query,
                overflowing_tokens,
                max_length=max_seq_length,
                return_overflowing_tokens=True,
                padding_strategy='right',
                stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
                truncation_strategy='only_second'
            )

            ids = encoded_dict['input_ids']
            print("Ids computes; position of the first padding", ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in ids else None)

            # Length of the document without the query
            paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
            non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
            tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)

            token_to_orig_map = {}
            for i in range(paragraph_len):
                token_to_orig_map[len(truncated_query) + sequence_added_tokens + i] = tok_to_orig_index[len(spans) * doc_stride + i]

            encoded_dict["paragraph_len"] = paragraph_len
            encoded_dict["tokens"] = tokens
            encoded_dict["token_to_orig_map"] = token_to_orig_map
            encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
            encoded_dict["token_is_max_context"] = {}
            encoded_dict["start"] = len(spans) * doc_stride
            encoded_dict["length"] = paragraph_len

            # split_token_index = doc_span.start + i
            # token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

            # is_max_context = _check_is_max_context(doc_spans, doc_span_index,
            #                                        split_token_index)
            # token_is_max_context[len(tokens)] = is_max_context
            # tokens.append(all_doc_tokens[split_token_index])

            spans.append(encoded_dict)

        for doc_span_index in range(len(spans)):
            for j in range(spans[doc_span_index]["paragraph_len"]):
                is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
                index = spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
                spans[doc_span_index]["token_is_max_context"][index] = is_max_context

        print("new span", len(spans))
        for span in spans:
            # Identify the position of the CLS token
            cls_index = span['input_ids'].index(tokenizer.cls_token_id)

            # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
            # Original TF implem also keep the classification token (set to 0) (not sure why...)
            p_mask = np.array(span['token_type_ids'])

            # Convert all SEP indices to '0' before inversion
            p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 0

            # Limit positive values to one
            p_mask = 1 - np.minimum(p_mask, 1)

            # Set the CLS index to '0'
            p_mask[cls_index] = 0

            print("new features length", len(new_features))
            new_features.append(NewSquadFeatures(
                span['input_ids'],
                span['attention_mask'],
                span['token_type_ids'],
                cls_index,
                p_mask.tolist(),
                example_index=example_index,
                unique_id=unique_id,
                paragraph_len=span['paragraph_len'],
                token_is_max_context=span["token_is_max_context"],
                tokens=span["tokens"],
                token_to_orig_map=span["token_to_orig_map"]
            ))

            unique_id += 1

        # tokenize ...
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
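Note (not part of the diff): each window stored in spans above carries a "start" and a "length", and _new_check_is_max_context flags a token in the window where it keeps the most surrounding context. A minimal self-contained sketch, assuming only the _new_check_is_max_context helper defined above and hand-built stand-ins for the encoded_dict windows:

    # Illustrative only: toy spans standing in for the encoded_dict windows built above.
    spans = [
        {"start": 0, "length": 6},  # covers document tokens 0..5
        {"start": 4, "length": 6},  # covers document tokens 4..9, overlapping tokens 4..5
    ]

    for position in (4, 5):
        flags = [_new_check_is_max_context(spans, i, position) for i in range(len(spans))]
        print(position, flags)
    # token 4 keeps more context in the first span  -> [True, False]
    # token 5 keeps more context in the second span -> [False, True]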
@@ -82,7 +268,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
-            if example.end_position < len(example.doc_tokens) - 1:
+            if example.end_position < len(doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
@@ -101,14 +287,19 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            print("OLD DOC CREATION BEGIN", start_offset, len(all_doc_tokens))
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                print("Done with this doc span, breaking out.", start_offset, length)
                break
            print("CHOOSING OFFSET", length, doc_stride)
            start_offset += min(length, doc_stride)
            print("OLD DOC CREATION END", start_offset)

        print("old span", len(doc_spans))
        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
@@ -183,18 +374,20 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            # tokens are attended to.
            input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(pad_token)
                input_mask.append(0 if mask_padding_with_zero else 1)
                segment_ids.append(pad_token_segment_id)
                p_mask.append(1)

            print("[OLD] Ids computed; position of the first padding", input_ids.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in input_ids else None)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

-            span_is_impossible = example.is_impossible
+            span_is_impossible = example.is_impossible if hasattr(example, "is_impossible") else False
            start_position = None
            end_position = None
            if is_training and not span_is_impossible:
@@ -222,31 +415,32 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                    start_position = cls_index
                    end_position = cls_index

-            if example_index < 20:
-                logger.info("*** Example ***")
-                logger.info("unique_id: %s" % (unique_id))
-                logger.info("example_index: %s" % (example_index))
-                logger.info("doc_span_index: %s" % (doc_span_index))
-                logger.info("tokens: %s" % " ".join(tokens))
-                logger.info("token_to_orig_map: %s" % " ".join(["%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()]))
-                logger.info("token_is_max_context: %s" % " ".join(["%d:%s" % (x, y) for (x, y) in token_is_max_context.items()]))
-                logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
-                logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
-                logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
-                if is_training and span_is_impossible:
-                    logger.info("impossible example")
-                if is_training and not span_is_impossible:
-                    answer_text = " ".join(tokens[start_position:(end_position + 1)])
-                    logger.info("start_position: %d" % (start_position))
-                    logger.info("end_position: %d" % (end_position))
-                    logger.info("answer: %s" % (answer_text))
+            # if example_index < 20:
+            #     logger.info("*** Example ***")
+            #     logger.info("unique_id: %s" % (unique_id))
+            #     logger.info("example_index: %s" % (example_index))
+            #     logger.info("doc_span_index: %s" % (doc_span_index))
+            #     logger.info("tokens: %s" % str(tokens))
+            #     logger.info("token_to_orig_map: %s" % " ".join([
+            #         "%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()]))
+            #     logger.info("token_is_max_context: %s" % " ".join([
+            #         "%d:%s" % (x, y) for (x, y) in token_is_max_context.items()
+            #     ]))
+            #     logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
+            #     logger.info(
+            #         "input_mask: %s" % " ".join([str(x) for x in input_mask]))
+            #     logger.info(
+            #         "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
+            #     if is_training and span_is_impossible:
+            #         logger.info("impossible example")
+            #     if is_training and not span_is_impossible:
+            #         answer_text = " ".join(tokens[start_position:(end_position + 1)])
+            #         logger.info("start_position: %d" % (start_position))
+            #         logger.info("end_position: %d" % (end_position))
+            #         logger.info(
+            #             "answer: %s" % (answer_text))

            print("features length", len(features))
            features.append(SquadFeatures(
                unique_id=unique_id,
@@ -266,7 +460,48 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                    is_impossible=span_is_impossible))
            unique_id += 1

-    return features
    assert len(features) == len(new_features)
    assert len(features) == len(new_features)
    for i in range(len(features)):
        print(i)
        feature, new_feature = features[i], new_features[i]

        input_ids = feature.input_ids
        input_mask = feature.input_mask
        segment_ids = feature.segment_ids
        cls_index = feature.cls_index
        p_mask = feature.p_mask
        example_index = feature.example_index
        paragraph_len = feature.paragraph_len
        token_is_max_context = feature.token_is_max_context
        tokens = feature.tokens
        token_to_orig_map = feature.token_to_orig_map

        new_input_ids = new_feature.input_ids
        new_input_mask = new_feature.attention_mask
        new_segment_ids = new_feature.token_type_ids
        new_cls_index = new_feature.cls_index
        new_p_mask = new_feature.p_mask
        new_example_index = new_feature.example_index
        new_paragraph_len = new_feature.paragraph_len
        new_token_is_max_context = new_feature.token_is_max_context
        new_tokens = new_feature.tokens
        new_token_to_orig_map = new_feature.token_to_orig_map

        assert input_ids == new_input_ids
        assert input_mask == new_input_mask
        assert segment_ids == new_segment_ids
        assert cls_index == new_cls_index
        assert p_mask == new_p_mask
        assert example_index == new_example_index
        assert paragraph_len == new_paragraph_len
        assert token_is_max_context == new_token_is_max_context
        assert tokens == new_tokens
        assert token_to_orig_map == new_token_to_orig_map

    return new_features


def read_squad_examples(input_file, is_training, version_2_with_negative):
@@ -347,6 +582,124 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
    return examples


class SquadV1Processor(DataProcessor):
    """Processor for the SQuAD data set."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return NewSquadExample(
            tensor_dict['id'].numpy(),
            tensor_dict['question'].numpy().decode('utf-8'),
            tensor_dict['context'].numpy().decode('utf-8'),
            tensor_dict['answers']['text'].numpy().decode('utf-8'),
            tensor_dict['answers']['answers_start'].numpy().decode('utf-8'),
            tensor_dict['title'].numpy().decode('utf-8')
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        with open(os.path.join(data_dir, "train-v1.1.json"), "r", encoding='utf-8') as reader:
            input_data = json.load(reader)["data"]
        return self._create_examples(input_data, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        with open(os.path.join(data_dir, "dev-v1.1.json"), "r", encoding='utf-8') as reader:
            input_data = json.load(reader)["data"]
        return self._create_examples(input_data, "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, input_data, set_type):
        """Creates examples for the training and dev sets."""
        is_training = set_type == "train"
        examples = []
        for entry in input_data:
            title = entry['title']
            for paragraph in entry["paragraphs"]:
                context_text = paragraph["context"]
                for qa in paragraph["qas"]:
                    qas_id = qa["id"]
                    question_text = qa["question"]
                    start_position = None
                    answer_text = None

                    if is_training:
                        if (len(qa["answers"]) != 1):
                            raise ValueError("For training, each question should have exactly 1 answer.")
                        answer = qa["answers"][0]
                        answer_text = answer['text']
                        start_position = answer['answer_start']

                    example = NewSquadExample(
                        qas_id=qas_id,
                        question_text=question_text,
                        context_text=context_text,
                        answer_text=answer_text,
                        start_position=start_position,
                        title=title
                    )

                    examples.append(example)
        return examples


class NewSquadExample(object):
    """
    A single training/test example for the Squad dataset, as loaded from disk.
    """

    def __init__(self, qas_id, question_text, context_text, answer_text, start_position, title):
        self.qas_id = qas_id
        self.question_text = question_text
        self.context_text = context_text
        self.answer_text = answer_text
        self.start_position = start_position
        self.title = title


class NewSquadFeatures(object):
    """
    Single squad example features to be fed to a model.
    Those features are model-specific.
    """

    def __init__(self, input_ids, attention_mask, token_type_ids,
                 cls_index, p_mask,
                 example_index, unique_id, paragraph_len,
                 token_is_max_context, tokens, token_to_orig_map):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.cls_index = cls_index
        self.p_mask = p_mask

        self.example_index = example_index
        self.unique_id = unique_id
        self.paragraph_len = paragraph_len
        self.token_is_max_context = token_is_max_context
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map


class SquadExample(object):
    """
    A single training/test example for the Squad dataset.
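Note (not part of the diff): a minimal usage sketch of the new SquadV1Processor together with the rewritten squad_convert_examples_to_features, assuming a BERT tokenizer from this repository, a local directory containing train-v1.1.json, and typical SQuAD hyperparameters; the path and the values below are placeholders, not anything fixed by this commit.

    # Illustrative only: data path and hyperparameter values are placeholders.
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    processor = SquadV1Processor()
    examples = processor.get_train_examples("/path/to/squad")  # expects train-v1.1.json in that directory

    features = squad_convert_examples_to_features(
        examples,
        tokenizer,
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_training=True,
    )
    print(len(features))  # one feature per overlapping span, across all examples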
@@ -423,18 +776,22 @@ class SquadFeatures(object):
        self.is_impossible = is_impossible

    def __eq__(self, other):
-        return self.cls_index == other.cls_index and \
-            self.doc_span_index == other.doc_span_index and \
-            self.end_position == other.end_position and \
-            self.example_index == other.example_index and \
+        print(self.example_index == other.example_index)
+        print(self.input_ids == other.input_ids)
+        print(self.input_mask == other.attention_mask)
+        print(self.p_mask == other.p_mask)
+        print(self.paragraph_len == other.paragraph_len)
+        print(self.segment_ids == other.token_type_ids)
+        print(self.token_is_max_context == other.token_is_max_context)
+        print(self.token_to_orig_map == other.token_to_orig_map)
+        print(self.tokens == other.tokens)
+        return self.example_index == other.example_index and \
            self.input_ids == other.input_ids and \
-            self.input_mask == other.input_mask and \
-            self.is_impossible == other.is_impossible and \
+            self.input_mask == other.attention_mask and \
            self.p_mask == other.p_mask and \
            self.paragraph_len == other.paragraph_len and \
-            self.segment_ids == other.segment_ids and \
-            self.start_position == other.start_position and \
+            self.segment_ids == other.token_type_ids and \
            self.token_is_max_context == other.token_is_max_context and \
            self.token_to_orig_map == other.token_to_orig_map and \
-            self.tokens == other.tokens and \
-            self.unique_id == other.unique_id
+            self.tokens == other.tokens
\ No newline at end of file
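Note (not part of the diff): the per-span p_mask construction added in squad_convert_examples_to_features can be checked in isolation. A small self-contained sketch, assuming toy BERT-style ids (101 = [CLS], 102 = [SEP], 0 = [PAD]; all values hypothetical): context tokens end up with 0 (may contain the answer), while query tokens, SEP tokens and padding end up with 1, and the CLS position is re-opened as the no-answer target.

    import numpy as np

    # Illustrative only: a hand-built span, not produced by the tokenizer.
    input_ids      = [101, 2054, 2003, 102, 3000, 2003, 1037, 2103, 102, 0, 0]
    token_type_ids = [0,   0,    0,    0,   1,    1,    1,    1,    1,   0, 0]
    cls_token_id, sep_token_id = 101, 102

    cls_index = input_ids.index(cls_token_id)

    p_mask = np.array(token_type_ids)
    p_mask[np.where(np.array(input_ids) == sep_token_id)[0]] = 0  # SEP positions cannot hold the answer
    p_mask = 1 - np.minimum(p_mask, 1)                            # invert: 1 = cannot be part of the answer
    p_mask[cls_index] = 0                                         # keep CLS available for the no-answer case

    print(p_mask.tolist())
    # [0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1]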