chenpangpang / transformers · Commits

Commit d4614729, authored Dec 13, 2019 by Lysandre

return for SQuAD [BLACKED]

parent f24a228a
Showing 2 changed files with 172 additions and 110 deletions:

transformers/data/processors/glue.py   +1  -1
transformers/data/processors/squad.py  +171  -109
transformers/data/processors/squad.py
@@ -18,19 +18,20 @@ if is_tf_available():

 logger = logging.getLogger(__name__)


 def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
     """Returns tokenized answer spans that better match the annotated answer."""
     tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

     for new_start in range(input_start, input_end + 1):
         for new_end in range(input_end, new_start - 1, -1):
             text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
             if text_span == tok_answer_text:
                 return (new_start, new_end)

     return (input_start, input_end)


 def _check_is_max_context(doc_spans, cur_span_index, position):
     """Check if this is the 'max context' doc span for the token."""
     best_score = None
@@ -50,6 +51,7 @@ def _check_is_max_context(doc_spans, cur_span_index, position):

     return cur_span_index == best_span_index


 def _new_check_is_max_context(doc_spans, cur_span_index, position):
     """Check if this is the 'max context' doc span for the token."""
     # if len(doc_spans) == 1:
@@ -71,14 +73,16 @@ def _new_check_is_max_context(doc_spans, cur_span_index, position):

     return cur_span_index == best_span_index


 def _is_whitespace(c):
     if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
         return True
     return False


-def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training,
-                                        return_dataset=False):
+def squad_convert_examples_to_features(
+    examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False
+):
     """
     Converts a list of examples into a list of features that can be directly given as input to a model.
     It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
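As an aside (not part of this commit), a minimal sketch of how the function above is typically driven end to end, assuming the transformers package at this revision and a local data/squad directory containing the SQuAD v2.0 JSON files; the hyperparameter values are illustrative.

    from transformers import BertTokenizer
    from transformers.data.processors.squad import SquadV2Processor, squad_convert_examples_to_features

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    processor = SquadV2Processor()
    examples = processor.get_dev_examples("data/squad")  # expects data/squad/dev-v2.0.json

    # return_dataset="pt" yields (features, TensorDataset); "tf" yields a tf.data.Dataset; False yields features only
    features, dataset = squad_convert_examples_to_features(
        examples,
        tokenizer,
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_training=False,
        return_dataset="pt",
    )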
@@ -123,13 +127,12 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,

             end_position = example.end_position

             # If the answer cannot be found in the text, then skip this example.
             actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
             cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
             if actual_text.find(cleaned_answer_text) == -1:
                 logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
                 continue

         tok_to_orig_index = []
         orig_to_tok_index = []
         all_doc_tokens = []
@@ -140,7 +143,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,

                 tok_to_orig_index.append(i)
                 all_doc_tokens.append(sub_token)

         if is_training and not example.is_impossible:
             tok_start_position = orig_to_tok_index[example.start_position]
             if example.end_position < len(example.doc_tokens) - 1:
@@ -154,7 +156,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,

         spans = []

-        truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
+        truncated_query = tokenizer.encode(
+            example.question_text, add_special_tokens=False, max_length=max_query_length
+        )
         sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
         sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
@@ -168,15 +172,18 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,

                 return_overflowing_tokens=True,
                 pad_to_max_length=True,
                 stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
-                truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first'
+                truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
             )

-            paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
+            paragraph_len = min(
+                len(all_doc_tokens) - len(spans) * doc_stride,
+                max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
+            )

-            if tokenizer.pad_token_id in encoded_dict['input_ids']:
-                non_padded_ids = encoded_dict['input_ids'][: encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
+            if tokenizer.pad_token_id in encoded_dict["input_ids"]:
+                non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
             else:
-                non_padded_ids = encoded_dict['input_ids']
+                non_padded_ids = encoded_dict["input_ids"]

             tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
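A hypothetical, self-contained illustration (not from the diff, and based on my understanding of the tokenizer API of this era) of the overflow mechanism the encode_plus call above relies on: with return_overflowing_tokens=True and a stride, one span of at most max_length tokens is returned together with the tokens that did not fit, which the surrounding loop feeds back in to build the next span. The texts and numbers below are made up.

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    enc = tokenizer.encode_plus(
        "What is SQuAD?",
        "The Stanford Question Answering Dataset is a reading comprehension dataset. " * 20,
        max_length=64,
        stride=16,
        truncation_strategy="only_second",
        return_overflowing_tokens=True,
        pad_to_max_length=True,
    )
    # one padded span plus whatever overflowed out of the context passage
    print(len(enc["input_ids"]), len(enc.get("overflowing_tokens", [])))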
@@ -202,16 +209,20 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,

         for doc_span_index in range(len(spans)):
             for j in range(spans[doc_span_index]["paragraph_len"]):
                 is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
-                index = j if tokenizer.padding_side == "left" else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
+                index = (
+                    j
+                    if tokenizer.padding_side == "left"
+                    else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
+                )
                 spans[doc_span_index]["token_is_max_context"][index] = is_max_context

         for span in spans:
             # Identify the position of the CLS token
-            cls_index = span['input_ids'].index(tokenizer.cls_token_id)
+            cls_index = span["input_ids"].index(tokenizer.cls_token_id)

             # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
             # Original TF implem also keep the classification token (set to 0) (not sure why...)
-            p_mask = np.array(span['token_type_ids'])
+            p_mask = np.array(span["token_type_ids"])

             p_mask = np.minimum(p_mask, 1)
@@ -224,7 +235,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,

             # Set the CLS index to '0'
             p_mask[cls_index] = 0

             span_is_impossible = example.is_impossible
             start_position = 0
             end_position = 0
@@ -251,51 +261,95 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,

                     start_position = tok_start_position - doc_start + doc_offset
                     end_position = tok_end_position - doc_start + doc_offset

             features.append(
                 SquadFeatures(
-                    span['input_ids'],
-                    span['attention_mask'],
-                    span['token_type_ids'],
+                    span["input_ids"],
+                    span["attention_mask"],
+                    span["token_type_ids"],
                     cls_index,
                     p_mask.tolist(),
                     example_index=example_index,
                     unique_id=unique_id,
-                    paragraph_len=span['paragraph_len'],
+                    paragraph_len=span["paragraph_len"],
                     token_is_max_context=span["token_is_max_context"],
                     tokens=span["tokens"],
                     token_to_orig_map=span["token_to_orig_map"],
                     start_position=start_position,
-                    end_position=end_position
-                ))
+                    end_position=end_position,
+                )
+            )
             unique_id += 1

-    if return_dataset == 'pt':
+    if return_dataset == "pt":
         if not is_torch_available():
             raise ImportError("Pytorch must be installed to return a pytorch dataset.")

         # Convert to Tensors and build dataset
         all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
-        all_input_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
-        all_segment_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
+        all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
+        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
         all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
         all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)

         if not is_training:
             all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
-            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
-                                    all_example_index, all_cls_index, all_p_mask)
+            dataset = TensorDataset(
+                all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask
+            )
         else:
             all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
             all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
-            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
-                                    all_start_positions, all_end_positions,
-                                    all_cls_index, all_p_mask)
+            dataset = TensorDataset(
+                all_input_ids,
+                all_attention_masks,
+                all_token_type_ids,
+                all_start_positions,
+                all_end_positions,
+                all_cls_index,
+                all_p_mask,
+            )

         return features, dataset
+    elif return_dataset == "tf":
+        if not is_tf_available():
+            raise ImportError("TensorFlow must be installed to return a TensorFlow dataset.")
+
+        def gen():
+            for ex in features:
+                yield (
+                    {
+                        "input_ids": ex.input_ids,
+                        "attention_mask": ex.attention_mask,
+                        "token_type_ids": ex.token_type_ids,
+                    },
+                    {
+                        "start_position": ex.start_position,
+                        "end_position": ex.end_position,
+                        "cls_index": ex.cls_index,
+                        "p_mask": ex.p_mask,
+                    },
+                )
+
+        return tf.data.Dataset.from_generator(
+            gen,
+            (
+                {"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32},
+                {"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32},
+            ),
+            (
+                {
+                    "input_ids": tf.TensorShape([None]),
+                    "attention_mask": tf.TensorShape([None]),
+                    "token_type_ids": tf.TensorShape([None]),
+                },
+                {
+                    "start_position": tf.TensorShape([]),
+                    "end_position": tf.TensorShape([]),
+                    "cls_index": tf.TensorShape([]),
+                    "p_mask": tf.TensorShape([None]),
+                },
+            ),
+        )

     return features
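Also not part of the diff: a short sketch of consuming the PyTorch dataset built in the branch above at evaluation time, assuming features and dataset came from a call with is_training=False and return_dataset="pt"; the batch unpacking mirrors the new TensorDataset column order.

    from torch.utils.data import DataLoader, SequentialSampler

    eval_dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=8)
    for batch in eval_dataloader:
        # columns follow the eval-time TensorDataset above
        input_ids, attention_mask, token_type_ids, example_indices, cls_index, p_mask = batch
        # input_ids / attention_mask / token_type_ids go to the model; example_indices map outputs back to features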
@@ -305,31 +359,32 @@ class SquadProcessor(DataProcessor):

     Processor for the SQuAD data set.
     Overriden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and version 2.0 of SQuAD, respectively.
     """

     train_file = None
     dev_file = None

     def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
         if not evaluate:
-            answer = tensor_dict['answers']['text'][0].numpy().decode('utf-8')
-            answer_start = tensor_dict['answers']['answer_start'][0].numpy()
+            answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8")
+            answer_start = tensor_dict["answers"]["answer_start"][0].numpy()
             answers = []
         else:
-            answers = [{
-                "answer_start": start.numpy(),
-                "text": text.numpy().decode('utf-8')}
-                for start, text in zip(tensor_dict['answers']["answer_start"], tensor_dict['answers']["text"])]
+            answers = [
+                {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")}
+                for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"])
+            ]

             answer = None
             answer_start = None

         return SquadExample(
-            qas_id=tensor_dict['id'].numpy().decode("utf-8"),
-            question_text=tensor_dict['question'].numpy().decode('utf-8'),
-            context_text=tensor_dict['context'].numpy().decode('utf-8'),
+            qas_id=tensor_dict["id"].numpy().decode("utf-8"),
+            question_text=tensor_dict["question"].numpy().decode("utf-8"),
+            context_text=tensor_dict["context"].numpy().decode("utf-8"),
             answer_text=answer,
             start_position_character=answer_start,
-            title=tensor_dict['title'].numpy().decode('utf-8'),
-            answers=answers
+            title=tensor_dict["title"].numpy().decode("utf-8"),
+            answers=answers,
         )

     def get_examples_from_dataset(self, dataset, evaluate=False):
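A hypothetical usage sketch for get_examples_from_dataset (not in the diff), assuming tensorflow_datasets is installed and that the method accepts the dict of splits returned by tfds.load, as in the library's example scripts of this era.

    import tensorflow_datasets as tfds
    from transformers.data.processors.squad import SquadV1Processor

    tfds_examples = tfds.load("squad")  # dict with "train" and "validation" splits
    examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=True)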
@@ -379,7 +434,9 @@ class SquadProcessor(DataProcessor):

         if self.train_file is None:
             raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

-        with open(os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding='utf-8') as reader:
+        with open(
+            os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
+        ) as reader:
             input_data = json.load(reader)["data"]
         return self._create_examples(input_data, "train")
@@ -398,7 +455,9 @@ class SquadProcessor(DataProcessor):

         if self.dev_file is None:
             raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

-        with open(os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding='utf-8') as reader:
+        with open(
+            os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
+        ) as reader:
             input_data = json.load(reader)["data"]
         return self._create_examples(input_data, "dev")
@@ -406,7 +465,7 @@ class SquadProcessor(DataProcessor):

         is_training = set_type == "train"
         examples = []
         for entry in tqdm(input_data):
-            title = entry['title']
+            title = entry["title"]
             for paragraph in entry["paragraphs"]:
                 context_text = paragraph["context"]
                 for qa in paragraph["qas"]:
@@ -424,8 +483,8 @@ class SquadProcessor(DataProcessor):

                     if not is_impossible:
                         if is_training:
                             answer = qa["answers"][0]
-                            answer_text = answer['text']
-                            start_position_character = answer['answer_start']
+                            answer_text = answer["text"]
+                            start_position_character = answer["answer_start"]
                         else:
                             answers = qa["answers"]
@@ -437,12 +496,13 @@ class SquadProcessor(DataProcessor):

                         start_position_character=start_position_character,
                         title=title,
                         is_impossible=is_impossible,
-                        answers=answers
+                        answers=answers,
                     )

                     examples.append(example)
         return examples


 class SquadV1Processor(SquadProcessor):
     train_file = "train-v1.1.json"
     dev_file = "dev-v1.1.json"
@@ -468,7 +528,8 @@ class SquadExample(object):

         is_impossible: False by default, set to True if the example has no possible answer.
     """

-    def __init__(self,
-                 qas_id,
-                 question_text,
-                 context_text,
+    def __init__(
+        self,
+        qas_id,
+        question_text,
+        context_text,

@@ -476,7 +537,8 @@ class SquadExample(object):

-                 start_position_character,
-                 title,
-                 answers=[],
-                 is_impossible=False):
+        start_position_character,
+        title,
+        answers=[],
+        is_impossible=False,
+    ):
         self.qas_id = qas_id
         self.question_text = question_text
         self.context_text = context_text
@@ -537,22 +599,21 @@ class SquadFeatures(object):

         end_position: end of the answer token index
     """

-    def __init__(self,
-                 input_ids,
-                 attention_mask,
-                 token_type_ids,
-                 cls_index,
-                 p_mask,
-                 example_index,
-                 unique_id,
-                 paragraph_len,
-                 token_is_max_context,
-                 tokens,
-                 token_to_orig_map,
-                 start_position,
-                 end_position):
+    def __init__(
+        self,
+        input_ids,
+        attention_mask,
+        token_type_ids,
+        cls_index,
+        p_mask,
+        example_index,
+        unique_id,
+        paragraph_len,
+        token_is_max_context,
+        tokens,
+        token_to_orig_map,
+        start_position,
+        end_position,
+    ):
         self.input_ids = input_ids
         self.attention_mask = attention_mask
@@ -580,6 +641,7 @@ class SquadResult(object):

         start_logits: The logits corresponding to the start of the answer
         end_logits: The logits corresponding to the end of the answer
     """

     def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
         self.start_logits = start_logits
         self.end_logits = end_logits
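Finally, a hypothetical sketch (not from the diff) of how SquadResult is typically filled from per-feature model outputs before evaluation; all_start_logits and all_end_logits are assumed to be lists aligned with the features produced above.

    results = []
    for feature, start_logits, end_logits in zip(features, all_start_logits, all_end_logits):
        # unique_id ties each result back to the corresponding SquadFeatures entry
        results.append(SquadResult(feature.unique_id, start_logits, end_logits))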