chenpangpang / transformers
Commit ea52f824, authored Nov 18, 2019 by Lysandre; committed Nov 22, 2019 by LysandreJik
Moved some SQuAD logic to /data
Parent: 26db31e0
Showing 4 changed files with 323 additions and 2 deletions (+323, -2)
transformers/__init__.py                    +2    -1
transformers/data/__init__.py               +2    -1
transformers/data/processors/__init__.py    +1    -0
transformers/data/processors/squad.py       +318  -0
transformers/__init__.py
...
@@ -25,7 +25,8 @@ from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH
 from .data import (is_sklearn_available,
                    InputExample, InputFeatures, DataProcessor,
                    glue_output_modes, glue_convert_examples_to_features,
-                   glue_processors, glue_tasks_num_labels)
+                   glue_processors, glue_tasks_num_labels,
+                   squad_convert_examples_to_features, SquadFeatures)

 if is_sklearn_available():
     from .data import glue_compute_metrics
...
transformers/data/__init__.py
-from .processors import InputExample, InputFeatures, DataProcessor
+from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures
 from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
+from .processors import squad_convert_examples_to_features
 from .metrics import is_sklearn_available
 if is_sklearn_available():
...
transformers/data/processors/__init__.py
 from .utils import InputExample, InputFeatures, DataProcessor
 from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
+from .squad import squad_convert_examples_to_features, SquadFeatures
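Taken together, the three __init__.py changes above re-export the new SQuAD helpers at every level of the package, so the same objects are reachable from the package root down to the defining module. A minimal sketch of the resulting import paths follows (an editorial aside, not part of the commit; it only assumes this revision is installed):

# All four lines resolve to the same objects after this commit; each level
# simply re-exports from the level below, as the diffs above show.
from transformers import squad_convert_examples_to_features, SquadFeatures
from transformers.data import squad_convert_examples_to_features, SquadFeatures
from transformers.data.processors import squad_convert_examples_to_features, SquadFeatures
from transformers.data.processors.squad import squad_convert_examples_to_features, SquadFeatures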
transformers/data/processors/squad.py (new file, 0 → 100644)
from tqdm import tqdm

import collections
import logging
import os

from .utils import DataProcessor, InputExample, InputFeatures
from ...file_utils import is_tf_available

if is_tf_available():
    import tensorflow as tf

logger = logging.getLogger(__name__)
def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                                       doc_stride, max_query_length, is_training,
                                       cls_token_at_end=False,
                                       cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
                                       sequence_a_segment_id=0, sequence_b_segment_id=1,
                                       cls_token_segment_id=0, pad_token_segment_id=0,
                                       mask_padding_with_zero=True,
                                       sequence_a_is_doc=False):
    """Loads a data file into a list of `InputBatch`s."""

    # Defining helper methods
    def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
                             orig_answer_text):
        """Returns tokenized answer spans that better match the annotated answer."""
        tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

        for new_start in range(input_start, input_end + 1):
            for new_end in range(input_end, new_start - 1, -1):
                text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
                if text_span == tok_answer_text:
                    return (new_start, new_end)

        return (input_start, input_end)
    def _check_is_max_context(doc_spans, cur_span_index, position):
        """Check if this is the 'max context' doc span for the token."""
        best_score = None
        best_span_index = None
        for (span_index, doc_span) in enumerate(doc_spans):
            end = doc_span.start + doc_span.length - 1
            if position < doc_span.start:
                continue
            if position > end:
                continue
            num_left_context = position - doc_span.start
            num_right_context = end - position
            score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
            if best_score is None or score > best_score:
                best_score = score
                best_span_index = span_index

        return cur_span_index == best_span_index
    unique_id = 1000000000

    features = []
    for (example_index, example) in enumerate(tqdm(examples)):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
                example.orig_answer_text)

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of the up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)
        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []

            # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
            # Original TF implem also keep the classification token (set to 0) (not sure why...)
            p_mask = []

            # CLS token at the beginning
            if not cls_token_at_end:
                tokens.append(cls_token)
                segment_ids.append(cls_token_segment_id)
                p_mask.append(0)
                cls_index = 0

            # XLNet: P SEP Q SEP CLS
            # Others: CLS Q SEP P SEP
            if not sequence_a_is_doc:
                # Query
                tokens += query_tokens
                segment_ids += [sequence_a_segment_id] * len(query_tokens)
                p_mask += [1] * len(query_tokens)

                # SEP token
                tokens.append(sep_token)
                segment_ids.append(sequence_a_segment_id)
                p_mask.append(1)

            # Paragraph
            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                                       split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                if not sequence_a_is_doc:
                    segment_ids.append(sequence_b_segment_id)
                else:
                    segment_ids.append(sequence_a_segment_id)
                p_mask.append(0)
            paragraph_len = doc_span.length

            if sequence_a_is_doc:
                # SEP token
                tokens.append(sep_token)
                segment_ids.append(sequence_a_segment_id)
                p_mask.append(1)

                tokens += query_tokens
                segment_ids += [sequence_b_segment_id] * len(query_tokens)
                p_mask += [1] * len(query_tokens)

            # SEP token
            tokens.append(sep_token)
            segment_ids.append(sequence_b_segment_id)
            p_mask.append(1)

            # CLS token at the end
            if cls_token_at_end:
                tokens.append(cls_token)
                segment_ids.append(cls_token_segment_id)
                p_mask.append(0)
                cls_index = len(tokens) - 1  # Index of classification token

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(pad_token)
                input_mask.append(0 if mask_padding_with_zero else 1)
                segment_ids.append(pad_token_segment_id)
                p_mask.append(1)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            span_is_impossible = example.is_impossible
            start_position = None
            end_position = None
            if is_training and not span_is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start and
                        tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                    span_is_impossible = True
                else:
                    if sequence_a_is_doc:
                        doc_offset = 0
                    else:
                        doc_offset = len(query_tokens) + 2
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            if is_training and span_is_impossible:
                start_position = cls_index
                end_position = cls_index
            if example_index < 20:
                logger.info("*** Example ***")
                logger.info("unique_id: %s" % (unique_id))
                logger.info("example_index: %s" % (example_index))
                logger.info("doc_span_index: %s" % (doc_span_index))
                logger.info("tokens: %s" % " ".join(tokens))
                logger.info("token_to_orig_map: %s" % " ".join([
                    "%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()]))
                logger.info("token_is_max_context: %s" % " ".join([
                    "%d:%s" % (x, y) for (x, y) in token_is_max_context.items()]))
                logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
                logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
                if is_training and span_is_impossible:
                    logger.info("impossible example")
                if is_training and not span_is_impossible:
                    answer_text = " ".join(tokens[start_position:(end_position + 1)])
                    logger.info("start_position: %d" % (start_position))
                    logger.info("end_position: %d" % (end_position))
                    logger.info("answer: %s" % (answer_text))
            features.append(
                SquadFeatures(
                    unique_id=unique_id,
                    example_index=example_index,
                    doc_span_index=doc_span_index,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    cls_index=cls_index,
                    p_mask=p_mask,
                    paragraph_len=paragraph_len,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=span_is_impossible))
            unique_id += 1

    return features
class SquadFeatures(object):
    """A single set of features of data."""

    def __init__(self,
                 unique_id,
                 example_index,
                 doc_span_index,
                 tokens,
                 token_to_orig_map,
                 token_is_max_context,
                 input_ids,
                 input_mask,
                 segment_ids,
                 cls_index,
                 p_mask,
                 paragraph_len,
                 start_position=None,
                 end_position=None,
                 is_impossible=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.doc_span_index = doc_span_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.cls_index = cls_index
        self.p_mask = p_mask
        self.paragraph_len = paragraph_len
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible
    def __eq__(self, other):
        return self.cls_index == other.cls_index and \
            self.doc_span_index == other.doc_span_index and \
            self.end_position == other.end_position and \
            self.example_index == other.example_index and \
            self.input_ids == other.input_ids and \
            self.input_mask == other.input_mask and \
            self.is_impossible == other.is_impossible and \
            self.p_mask == other.p_mask and \
            self.paragraph_len == other.paragraph_len and \
            self.segment_ids == other.segment_ids and \
            self.start_position == other.start_position and \
            self.token_is_max_context == other.token_is_max_context and \
            self.token_to_orig_map == other.token_to_orig_map and \
            self.tokens == other.tokens and \
            self.unique_id == other.unique_id
\ No newline at end of file
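For context, a minimal usage sketch of the relocated converter follows (an editorial aside, not part of the commit). The bert-base-uncased checkpoint, the types.SimpleNamespace stand-in for a SquadExample, and the max_seq_length/doc_stride/max_query_length values are assumptions for illustration; the keyword names mirror the function signature in the new file.

# Hedged sketch: run squad_convert_examples_to_features on one toy example.
# The stand-in object only needs the attributes the converter reads in
# evaluation mode: question_text, doc_tokens and is_impossible.
import types

from transformers import BertTokenizer, squad_convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

example = types.SimpleNamespace(
    question_text="Who wrote the play?",
    doc_tokens="The play was written by William Shakespeare .".split(),
    is_impossible=False)

features = squad_convert_examples_to_features(
    examples=[example],
    tokenizer=tokenizer,
    max_seq_length=384,   # question + document chunk, padded to this length
    doc_stride=128,       # overlap used by the sliding-window chunking
    max_query_length=64,  # longer questions are truncated
    is_training=False)    # evaluation mode: no answer positions required

print(len(features), len(features[0].input_ids))  # 1 384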