chenpangpang / transformers

Commit 0669c1fc, authored Nov 25, 2019 by Lysandre
Parent: e0e55bc5

SQuAD v2 BERT + XLNet
Showing 4 changed files with 92 additions and 94 deletions:

- transformers/__init__.py (+1 / -1)
- transformers/data/__init__.py (+1 / -1)
- transformers/data/processors/__init__.py (+1 / -1)
- transformers/data/processors/squad.py (+89 / -91)
transformers/__init__.py

```diff
@@ -27,7 +27,7 @@ from .data import (is_sklearn_available,
                    glue_output_modes, glue_convert_examples_to_features,
                    glue_processors, glue_tasks_num_labels,
                    squad_convert_examples_to_features, SquadFeatures,
-                   SquadExample, read_squad_examples)
+                   SquadExample)

 if is_sklearn_available():
     from .data import glue_compute_metrics
```
transformers/data/__init__.py

```diff
 from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures
 from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
-from .processors import squad_convert_examples_to_features, SquadExample, read_squad_examples

+from .processors import squad_convert_examples_to_features, SquadExample

 from .metrics import is_sklearn_available

 if is_sklearn_available():
```
transformers/data/processors/__init__.py

```diff
 from .utils import InputExample, InputFeatures, DataProcessor
 from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
-from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, read_squad_examples
+from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample
```
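Net effect of the three `__init__.py` edits: `read_squad_examples` disappears from the public import chain at every level, while the processor-based helpers stay exported. A minimal sketch of the import surface after this commit (assuming the package is installed at this revision):

```python
# Still valid after this commit:
from transformers import SquadExample, SquadFeatures, squad_convert_examples_to_features

# Removed by this commit; this would now raise ImportError:
# from transformers import read_squad_examples
```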
transformers/data/processors/squad.py

```diff
@@ -46,7 +46,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
     return cur_span_index == best_span_index

-
 def _new_check_is_max_context(doc_spans, cur_span_index, position):
     """Check if this is the 'max context' doc span for the token."""
     # if len(doc_spans) == 1:
```
```diff
@@ -92,7 +91,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
     features = []
     new_features = []
     for (example_index, example) in enumerate(tqdm(examples)):
-        if is_training:
+        if is_training and not example.is_impossible:
             # Get start and end position
             answer_length = len(example.answer_text)
             start_position = example.start_position
```
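This guard is the heart of the SQuAD v2 support: unanswerable questions carry no gold span, so answer positions are only read off answerable training examples. A toy illustration; the `Example` stand-in below is hypothetical, not the class defined in this file:

```python
# Hypothetical stand-in to illustrate the guard added above.
class Example:
    def __init__(self, answer_text, start_position, is_impossible):
        self.answer_text = answer_text
        self.start_position = start_position
        self.is_impossible = is_impossible

examples = [
    Example("Denver Broncos", 177, is_impossible=False),  # SQuAD v1-style answerable
    Example("", 0, is_impossible=True),                   # SQuAD v2 unanswerable
]

is_training = True
for example in examples:
    if is_training and not example.is_impossible:
        # Only answerable examples have a gold span to locate.
        answer_length = len(example.answer_text)
        print(example.start_position, example.start_position + answer_length - 1)
```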
```diff
@@ -105,6 +104,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                 logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
                 continue
+
         tok_to_orig_index = []
         orig_to_tok_index = []
         all_doc_tokens = []
```
```diff
@@ -115,6 +115,18 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                 tok_to_orig_index.append(i)
                 all_doc_tokens.append(sub_token)

+        if is_training and not example.is_impossible:
+            tok_start_position = orig_to_tok_index[example.start_position]
+            if example.end_position < len(example.doc_tokens) - 1:
+                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
+            else:
+                tok_end_position = len(all_doc_tokens) - 1
+            (tok_start_position, tok_end_position) = _improve_answer_span(
+                all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
+            )
+
         spans = []
         truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
```
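The added block maps word-level answer positions into subtoken space: `orig_to_tok_index[i]` is the index of the first subtoken of word `i`, so the end of word `i` is one before the start of word `i + 1`, or the last subtoken when the answer ends the document. A minimal sketch of that arithmetic with a hypothetical, hard-coded subtokenization:

```python
doc_tokens = ["unaffable", "weather", "today"]
# Toy WordPiece-style split, hard-coded for illustration.
subtokens_per_word = [["un", "##aff", "##able"], ["weather"], ["to", "##day"]]

tok_to_orig_index = []   # subtoken index -> word index
orig_to_tok_index = []   # word index -> index of its first subtoken
all_doc_tokens = []
for i, word_subtokens in enumerate(subtokens_per_word):
    orig_to_tok_index.append(len(all_doc_tokens))
    for sub_token in word_subtokens:
        tok_to_orig_index.append(i)
        all_doc_tokens.append(sub_token)

# Word span [0, 1] ("unaffable weather") in subtoken space:
start_word, end_word = 0, 1
tok_start = orig_to_tok_index[start_word]
if end_word < len(doc_tokens) - 1:
    tok_end = orig_to_tok_index[end_word + 1] - 1
else:
    tok_end = len(all_doc_tokens) - 1
print(tok_start, tok_end)  # -> 0 3, i.e. ["un", "##aff", "##able", "weather"]
```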
```diff
@@ -187,6 +199,34 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             # Set the CLS index to '0'
             p_mask[cls_index] = 0

+            span_is_impossible = example.is_impossible
+            start_position = 0
+            end_position = 0
+            if is_training and not span_is_impossible:
+                # For training, if our document chunk does not contain an annotation
+                # we throw it out, since there is nothing to predict.
+                doc_start = span["start"]
+                doc_end = span["start"] + span["length"] - 1
+                out_of_span = False
+                if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
+                    out_of_span = True
+                if out_of_span:
+                    start_position = cls_index
+                    end_position = cls_index
+                    span_is_impossible = True
+                else:
+                    if sequence_a_is_doc:
+                        doc_offset = 0
+                    else:
+                        doc_offset = len(truncated_query) + sequence_added_tokens
+                    start_position = tok_start_position - doc_start + doc_offset
+                    end_position = tok_end_position - doc_start + doc_offset
+
             new_features.append(NewSquadFeatures(
                 span['input_ids'],
                 span['attention_mask'],
```
```diff
@@ -199,7 +239,10 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                 paragraph_len=span['paragraph_len'],
                 token_is_max_context=span["token_is_max_context"],
                 tokens=span["tokens"],
-                token_to_orig_map=span["token_to_orig_map"]
+                token_to_orig_map=span["token_to_orig_map"],
+                start_position=start_position,
+                end_position=end_position
             ))

             unique_id += 1
```
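Per document chunk, the new code decides whether the gold span survives the sliding window: if the answer's subtoken span falls outside `[doc_start, doc_end]`, the chunk is labeled impossible and both positions point at the CLS token; otherwise the positions are shifted by the query offset. A worked toy example of that arithmetic (all numbers illustrative):

```python
cls_index = 0
tok_start_position, tok_end_position = 25, 27  # answer span in subtoken space
truncated_query_len = 10
sequence_added_tokens = 2                      # e.g. [CLS] and [SEP]

for doc_start, doc_length in [(0, 20), (15, 20)]:
    doc_end = doc_start + doc_length - 1
    out_of_span = not (tok_start_position >= doc_start and tok_end_position <= doc_end)
    if out_of_span:
        # No answer inside this chunk: point both labels at CLS.
        start_position = end_position = cls_index
    else:
        # Question comes first, so shift by query length plus special tokens.
        doc_offset = truncated_query_len + sequence_added_tokens
        start_position = tok_start_position - doc_start + doc_offset
        end_position = tok_end_position - doc_start + doc_offset
    print(doc_start, out_of_span, start_position, end_position)
# chunk (0, 20):  span [25, 27] is outside -> labels at CLS: (0, 0)
# chunk (15, 20): doc_end = 34, span fits  -> labels (22, 24)
```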
```diff
@@ -207,86 +250,10 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
     return new_features


-def read_squad_examples(input_file, is_training, version_2_with_negative):
-    """Read a SQuAD json file into a list of SquadExample."""
-    with open(input_file, "r", encoding='utf-8') as reader:
-        input_data = json.load(reader)["data"]
-
-    def is_whitespace(c):
-        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
-            return True
-        return False
-
-    examples = []
-    for entry in input_data:
-        for paragraph in entry["paragraphs"]:
-            paragraph_text = paragraph["context"]
-            doc_tokens = []
-            char_to_word_offset = []
-            prev_is_whitespace = True
-            for c in paragraph_text:
-                if is_whitespace(c):
-                    prev_is_whitespace = True
-                else:
-                    if prev_is_whitespace:
-                        doc_tokens.append(c)
-                    else:
-                        doc_tokens[-1] += c
-                    prev_is_whitespace = False
-                char_to_word_offset.append(len(doc_tokens) - 1)
-
-            for qa in paragraph["qas"]:
-                qas_id = qa["id"]
-                question_text = qa["question"]
-                start_position = None
-                end_position = None
-                orig_answer_text = None
-                is_impossible = False
-                if is_training:
-                    if version_2_with_negative:
-                        is_impossible = qa["is_impossible"]
-                    if (len(qa["answers"]) != 1) and (not is_impossible):
-                        raise ValueError("For training, each question should have exactly 1 answer.")
-                    if not is_impossible:
-                        answer = qa["answers"][0]
-                        orig_answer_text = answer["text"]
-                        answer_offset = answer["answer_start"]
-                        answer_length = len(orig_answer_text)
-                        start_position = char_to_word_offset[answer_offset]
-                        end_position = char_to_word_offset[answer_offset + answer_length - 1]
-                        # Only add answers where the text can be exactly recovered from the
-                        # document. If this CAN'T happen it's likely due to weird Unicode
-                        # stuff so we will just skip the example.
-                        #
-                        # Note that this means for training mode, every example is NOT
-                        # guaranteed to be preserved.
-                        actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
-                        cleaned_answer_text = " ".join(whitespace_tokenize(orig_answer_text))
-                        if actual_text.find(cleaned_answer_text) == -1:
-                            logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
-                            continue
-                    else:
-                        start_position = -1
-                        end_position = -1
-                        orig_answer_text = ""
-
-                example = SquadExample(
-                    qas_id=qas_id,
-                    question_text=question_text,
-                    doc_tokens=doc_tokens,
-                    orig_answer_text=orig_answer_text,
-                    start_position=start_position,
-                    end_position=end_position,
-                    is_impossible=is_impossible)
-                examples.append(example)
-    return examples
-
-
-class SquadV1Processor(DataProcessor):
+class SquadProcessor(DataProcessor):
     """Processor for the SQuAD data set."""
+    train_file = None
+    dev_file = None

     def get_example_from_tensor_dict(self, tensor_dict):
         """See base class."""
```
```diff
@@ -301,13 +268,19 @@ class SquadV1Processor(DataProcessor):
     def get_train_examples(self, data_dir, only_first=None):
         """See base class."""
-        with open(os.path.join(data_dir, "train-v1.1.json"), "r", encoding='utf-8') as reader:
+        if self.train_file is None:
+            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
+        with open(os.path.join(data_dir, self.train_file), "r", encoding='utf-8') as reader:
             input_data = json.load(reader)["data"]
         return self._create_examples(input_data, "train", only_first)

     def get_dev_examples(self, data_dir, only_first=None):
         """See base class."""
-        with open(os.path.join(data_dir, "dev-v1.1.json"), "r", encoding='utf-8') as reader:
+        if self.dev_file is None:
+            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
+        with open(os.path.join(data_dir, self.dev_file), "r", encoding='utf-8') as reader:
             input_data = json.load(reader)["data"]
         return self._create_examples(input_data, "dev", only_first)
```
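With the file name moved into the `train_file`/`dev_file` class attributes, each dataset version becomes a two-line subclass (defined further down in this diff), and the base class guards against direct use. A hedged usage sketch; the data directory path is illustrative:

```python
# Assumes the classes from this file are importable at this revision.
from transformers.data.processors.squad import SquadProcessor, SquadV2Processor

processor = SquadV2Processor()
train_examples = processor.get_train_examples("data/squad")  # reads train-v2.0.json
dev_examples = processor.get_dev_examples("data/squad")      # reads dev-v2.0.json

# The guard added above: the base class has no file configured.
# SquadProcessor().get_train_examples("data/squad")  # -> ValueError
```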
```diff
@@ -329,7 +302,13 @@ class SquadV1Processor(DataProcessor):
                     question_text = qa["question"]
                     start_position_character = None
                     answer_text = None
-                    if is_training:
+
+                    if "is_impossible" in qa:
+                        is_impossible = qa["is_impossible"]
+                    else:
+                        is_impossible = False
+
+                    if not is_impossible and is_training:
                         if (len(qa["answers"]) != 1):
                             raise ValueError(
                                 "For training, each question should have exactly 1 answer.")
```
```diff
@@ -343,14 +322,24 @@ class SquadV1Processor(DataProcessor):
                         context_text=context_text,
                         answer_text=answer_text,
                         start_position_character=start_position_character,
-                        title=title
+                        title=title,
+                        is_impossible=is_impossible
                     )
                     examples.append(example)
                     if only_first is not None and len(examples) > only_first:
                         return examples
         return examples


+class SquadV1Processor(SquadProcessor):
+    train_file = "train-v1.1.json"
+    dev_file = "dev-v1.1.json"
+
+
+class SquadV2Processor(SquadProcessor):
+    train_file = "train-v2.0.json"
+    dev_file = "dev-v2.0.json"
+
+
 class NewSquadExample(object):
```
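Because `is_impossible` is read with a key-presence check, the same `_create_examples` handles both schema versions: SQuAD v1.1 entries simply lack the key. A minimal sketch of that dispatch with toy `qa` dicts:

```python
qas = [
    {"id": "v1-style", "question": "Who?", "answers": [{"text": "Lee", "answer_start": 0}]},
    {"id": "v2-style", "question": "Who?", "answers": [], "is_impossible": True},
]

for qa in qas:
    is_impossible = qa["is_impossible"] if "is_impossible" in qa else False
    print(qa["id"], is_impossible)
# v1-style False  (key absent -> treated as answerable)
# v2-style True   (SQuAD v2 unanswerable question)
```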
```diff
@@ -364,13 +353,16 @@ class NewSquadExample(object):
                  context_text,
                  answer_text,
                  start_position_character,
-                 title):
+                 title,
+                 is_impossible=False):
         self.qas_id = qas_id
         self.question_text = question_text
         self.context_text = context_text
         self.answer_text = answer_text
         self.title = title
-        self.is_impossible = False
+        self.is_impossible = is_impossible
+        self.start_position, self.end_position = 0, 0

         doc_tokens = []
         char_to_word_offset = []
```
```diff
@@ -392,7 +384,7 @@ class NewSquadExample(object):
         self.char_to_word_offset = char_to_word_offset

         # Start end end positions only has a value during evaluation.
-        if start_position_character is not None:
+        if start_position_character is not None and not is_impossible:
             self.start_position = char_to_word_offset[start_position_character]
             self.end_position = char_to_word_offset[start_position_character + len(answer_text) - 1]
```
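With positions initialized to `(0, 0)` in `__init__`, an unanswerable example never indexes `char_to_word_offset` with a meaningless character offset; it just keeps the CLS-style default. A hedged construction sketch (argument values are illustrative, and this assumes only the constructor behavior visible in this diff):

```python
answerable = NewSquadExample(
    qas_id="q1", question_text="Who won?", context_text="Lee won.",
    answer_text="Lee", start_position_character=0, title="t",
)
unanswerable = NewSquadExample(
    qas_id="q2", question_text="When?", context_text="Lee won.",
    answer_text="", start_position_character=None, title="t",
    is_impossible=True,
)
print(answerable.start_position, answerable.end_position)      # word span of "Lee": (0, 0)
print(unanswerable.start_position, unanswerable.end_position)  # untouched default: (0, 0)
```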
```diff
@@ -415,7 +407,10 @@ class NewSquadFeatures(object):
                  paragraph_len,
                  token_is_max_context,
                  tokens,
-                 token_to_orig_map
+                 token_to_orig_map,
+                 start_position,
+                 end_position
                  ):
         self.input_ids = input_ids
         self.attention_mask = attention_mask
```
```diff
@@ -430,6 +425,9 @@ class NewSquadFeatures(object):
         self.tokens = tokens
         self.token_to_orig_map = token_to_orig_map
+        self.start_position = start_position
+        self.end_position = end_position
+

 class SquadExample(object):
     """
     A single training/test example for the Squad dataset.
```
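Putting the pieces together, the intended flow after this commit is processor -> examples -> features. A hedged end-to-end sketch; the keyword arguments of `squad_convert_examples_to_features` beyond those visible in the hunk headers and diff body are assumptions:

```python
# Sketch only: doc_stride is an assumed parameter; max_query_length and
# is_training are visible in the diff body above.
from transformers import BertTokenizer
from transformers.data.processors.squad import SquadV2Processor, squad_convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
examples = SquadV2Processor().get_train_examples("data/squad")

features = squad_convert_examples_to_features(
    examples,
    tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=True,
)
print(len(features))
```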