chenpangpang / transformers · Commits

Commit 8e9526b4, authored Dec 14, 2019 by erenup
parent 9b312f9d

    add multiple processing

Showing 2 changed files with 187 additions and 160 deletions:
  examples/run_squad.py                    +4    -1
  transformers/data/processors/squad.py    +183  -159
examples/run_squad.py

@@ -360,7 +360,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
             doc_stride=args.doc_stride,
             max_query_length=args.max_query_length,
             is_training=not evaluate,
-            return_dataset='pt'
+            return_dataset='pt',
+            threads=args.threads,
         )

     if args.local_rank in [-1, 0]:
@@ -478,6 +479,8 @@ def main():
                              "See details at https://nvidia.github.io/apex/amp.html")
     parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
     parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
+    parser.add_argument('--threads', type=int, default=1,
+                        help='multiple threads for converting example to features')
     args = parser.parse_args()

     if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
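The run_squad.py side of the change is just plumbing: the new --threads option is parsed and handed to the feature-conversion call. The simplified sketch below illustrates that flow; the names convert and the hard-coded argument list are hypothetical stand-ins, not part of the commit.

# Simplified, hypothetical sketch of how --threads reaches the feature converter.
import argparse

def convert(examples, threads=1):
    # stand-in for squad_convert_examples_to_features(..., threads=threads)
    print(f"converting {len(examples)} examples with {threads} worker process(es)")

parser = argparse.ArgumentParser()
parser.add_argument('--threads', type=int, default=1,
                    help='multiple threads for converting example to features')
args = parser.parse_args(['--threads', '4'])   # in the real script this comes from the command line
convert(examples=list(range(10)), threads=args.threads)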
transformers/data/processors/squad.py

@@ -4,6 +4,9 @@ import logging
 import os
 import json
 import numpy as np
+from multiprocessing import Pool
+from multiprocessing import cpu_count
+from functools import partial
 
 from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
 from .utils import DataProcessor, InputExample, InputFeatures
@@ -76,9 +79,168 @@ def _is_whitespace(c):
         return True
     return False
 
+
+def squad_convert_example_to_features(example, max_seq_length,
+                                      doc_stride, max_query_length, is_training):
+    features = []
+    if is_training and not example.is_impossible:
+        # Get start and end position
+        start_position = example.start_position
+        end_position = example.end_position
+
+        # If the answer cannot be found in the text, then skip this example.
+        actual_text = " ".join(example.doc_tokens[start_position:(end_position + 1)])
+        cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
+        if actual_text.find(cleaned_answer_text) == -1:
+            logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
+            return []
+
+    tok_to_orig_index = []
+    orig_to_tok_index = []
+    all_doc_tokens = []
+    for (i, token) in enumerate(example.doc_tokens):
+        orig_to_tok_index.append(len(all_doc_tokens))
+        sub_tokens = tokenizer.tokenize(token)
+        for sub_token in sub_tokens:
+            tok_to_orig_index.append(i)
+            all_doc_tokens.append(sub_token)
+
+    if is_training and not example.is_impossible:
+        tok_start_position = orig_to_tok_index[example.start_position]
+        if example.end_position < len(example.doc_tokens) - 1:
+            tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
+        else:
+            tok_end_position = len(all_doc_tokens) - 1
+
+        (tok_start_position, tok_end_position) = _improve_answer_span(
+            all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text)
+
+    spans = []
+
+    truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
+    sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence + 1 \
+        if 'roberta' in str(type(tokenizer)) else tokenizer.max_len - tokenizer.max_len_single_sentence
+    sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
+
+    span_doc_tokens = all_doc_tokens
+    while len(spans) * doc_stride < len(all_doc_tokens):
+
+        encoded_dict = tokenizer.encode_plus(
+            truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
+            span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
+            max_length=max_seq_length,
+            return_overflowing_tokens=True,
+            pad_to_max_length=True,
+            stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
+            truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first'
+        )
+
+        paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride,
+                            max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
+
+        if tokenizer.pad_token_id in encoded_dict['input_ids']:
+            non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
+        else:
+            non_padded_ids = encoded_dict['input_ids']
+
+        tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
+
+        token_to_orig_map = {}
+        for i in range(paragraph_len):
+            index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
+            token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
+
+        encoded_dict["paragraph_len"] = paragraph_len
+        encoded_dict["tokens"] = tokens
+        encoded_dict["token_to_orig_map"] = token_to_orig_map
+        encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
+        encoded_dict["token_is_max_context"] = {}
+        encoded_dict["start"] = len(spans) * doc_stride
+        encoded_dict["length"] = paragraph_len
+
+        spans.append(encoded_dict)
+
+        if "overflowing_tokens" not in encoded_dict:
+            break
+        span_doc_tokens = encoded_dict["overflowing_tokens"]
+
+    for doc_span_index in range(len(spans)):
+        for j in range(spans[doc_span_index]["paragraph_len"]):
+            is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
+            index = j if tokenizer.padding_side == "left" else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
+            spans[doc_span_index]["token_is_max_context"][index] = is_max_context
+
+    for span in spans:
+        # Identify the position of the CLS token
+        cls_index = span['input_ids'].index(tokenizer.cls_token_id)
+
+        # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
+        # Original TF implem also keep the classification token (set to 0) (not sure why...)
+        p_mask = np.array(span['token_type_ids'])
+        p_mask = np.minimum(p_mask, 1)
+
+        if tokenizer.padding_side == "right":
+            # Limit positive values to one
+            p_mask = 1 - p_mask
+
+        p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1
+
+        # Set the CLS index to '0'
+        p_mask[cls_index] = 0
+
+        span_is_impossible = example.is_impossible
+        start_position = 0
+        end_position = 0
+        if is_training and not span_is_impossible:
+            # For training, if our document chunk does not contain an annotation
+            # we throw it out, since there is nothing to predict.
+            doc_start = span["start"]
+            doc_end = span["start"] + span["length"] - 1
+            out_of_span = False
+
+            if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
+                out_of_span = True
+
+            if out_of_span:
+                start_position = cls_index
+                end_position = cls_index
+                span_is_impossible = True
+            else:
+                if tokenizer.padding_side == "left":
+                    doc_offset = 0
+                else:
+                    doc_offset = len(truncated_query) + sequence_added_tokens
+
+                start_position = tok_start_position - doc_start + doc_offset
+                end_position = tok_end_position - doc_start + doc_offset
+
+        features.append(SquadFeatures(
+            span['input_ids'],
+            span['attention_mask'],
+            span['token_type_ids'],
+            cls_index,
+            p_mask.tolist(),
+            example_index=0,
+            unique_id=0,
+            paragraph_len=span['paragraph_len'],
+            token_is_max_context=span["token_is_max_context"],
+            tokens=span["tokens"],
+            token_to_orig_map=span["token_to_orig_map"],
+            start_position=start_position,
+            end_position=end_position
+        ))
+    return features
+
+
+def squad_convert_example_to_features_init(tokenizer_for_convert):
+    global tokenizer
+    tokenizer = tokenizer_for_convert
+
+
 def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                                        doc_stride, max_query_length, is_training,
-                                       return_dataset=False):
+                                       return_dataset=False, threads=1):
     """
     Converts a list of examples into a list of features that can be directly given as input to a model.
     It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
@@ -93,6 +255,8 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         return_dataset: Default False. Either 'pt' or 'tf'.
             if 'pt': returns a torch.data.TensorDataset,
             if 'tf': returns a tf.data.Dataset
+        threads: multiple processing threads.
+
     Returns:
         list of :class:`~transformers.data.processors.squad.SquadFeatures`
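Based on the updated signature and docstring, a call with the new threads argument looks roughly like the sketch below. The tokenizer choice and the data directory are assumptions for illustration; with return_dataset='pt' the function returns both the feature list and a torch TensorDataset, as the last hunk of this file shows.

# Usage sketch (illustrative assumptions: bert-base-uncased tokenizer, local SQuAD data directory).
from transformers import BertTokenizer
from transformers.data.processors.squad import SquadV2Processor, squad_convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
processor = SquadV2Processor()
examples = processor.get_dev_examples('/path/to/squad')   # assumed data location

features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset='pt',   # per the docstring: also returns a torch.data.TensorDataset
    threads=4,             # new in this commit: number of worker processes
)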
@@ -113,165 +277,26 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
     """
 
     # Defining helper methods
-    unique_id = 1000000000
     features = []
-    for (example_index, example) in enumerate(tqdm(examples)):
-        if is_training and not example.is_impossible:
-            # Get start and end position
-            start_position = example.start_position
-            end_position = example.end_position
-
-            # If the answer cannot be found in the text, then skip this example.
-            actual_text = " ".join(example.doc_tokens[start_position:(end_position + 1)])
-            cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
-            if actual_text.find(cleaned_answer_text) == -1:
-                logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
-                continue
-
-        tok_to_orig_index = []
-        orig_to_tok_index = []
-        all_doc_tokens = []
-        [... the remaining deleted lines were the old inline per-example span-construction code, now
-         moved into squad_convert_example_to_features above; the only differences were that each
-         SquadFeatures was built with example_index=example_index and unique_id=unique_id, and
-         unique_id was incremented after every feature ...]
+    threads = min(threads, cpu_count())
+    with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
+        annotate_ = partial(squad_convert_example_to_features, max_seq_length=max_seq_length,
+                            doc_stride=doc_stride, max_query_length=max_query_length, is_training=is_training)
+        features = list(tqdm(p.imap(annotate_, examples, chunksize=32), total=len(examples),
+                             desc='convert squad examples to features'))
+    new_features = []
+    unique_id = 1000000000
+    example_index = 0
+    for example_features in tqdm(features, total=len(features), desc='add example index and unique id'):
+        if not example_features:
+            continue
+        for example_feature in example_features:
+            example_feature.example_index = example_index
+            example_feature.unique_id = unique_id
+            new_features.append(example_feature)
+            unique_id += 1
+        example_index += 1
+    features = new_features
+    del new_features
 
     if return_dataset == 'pt':
         if not is_torch_available():
             raise ImportError("Pytorch must be installed to return a pytorch dataset.")
@@ -295,7 +320,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 all_cls_index, all_p_mask)
 
         return features, dataset
 
     return features
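The core of the commit is a standard multiprocessing pattern: a Pool initializer stashes the expensive-to-pickle tokenizer in a per-worker global, functools.partial binds the fixed conversion arguments, and Pool.imap streams per-example results back in order. The self-contained toy sketch below shows the same pattern; all names in it are hypothetical and not library code.

# Toy illustration of the Pool-initializer + partial + imap pattern used by this commit.
from functools import partial
from multiprocessing import Pool, cpu_count

tokenizer = None  # populated inside each worker process by the initializer

def init_worker(tokenizer_for_convert):
    # runs once per worker, mirroring squad_convert_example_to_features_init
    global tokenizer
    tokenizer = tokenizer_for_convert

def convert_one(example, max_seq_length):
    # stand-in for squad_convert_example_to_features: uses the per-worker global tokenizer
    return [f"{tokenizer}:{tok.lower()}" for tok in example.split()][:max_seq_length]

if __name__ == '__main__':
    examples = ["The quick brown Fox", "jumps over the LAZY dog"] * 4
    threads = min(2, cpu_count())
    with Pool(threads, initializer=init_worker, initargs=("fake-tokenizer",)) as p:
        annotate = partial(convert_one, max_seq_length=3)
        features = list(p.imap(annotate, examples, chunksize=4))
    print(features[0])  # ['fake-tokenizer:the', 'fake-tokenizer:quick', 'fake-tokenizer:brown']

Passing the tokenizer through the initializer rather than through partial means it is sent to each worker once, instead of being re-pickled with every chunk of examples.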