Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ce158a07
Commit
ce158a07
authored
Dec 04, 2019
by
LysandreJik
Browse files
Return dataset (pytorch)
parent
7a035199
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
4 deletions
+37
-4
transformers/data/processors/squad.py
transformers/data/processors/squad.py
+37
-4
No files found.
transformers/data/processors/squad.py
View file @
ce158a07
...
...
@@ -7,7 +7,11 @@ import numpy as np
from
...tokenization_bert
import
BasicTokenizer
,
whitespace_tokenize
from
.utils
import
DataProcessor
,
InputExample
,
InputFeatures
from
...file_utils
import
is_tf_available
from
...file_utils
import
is_tf_available
,
is_torch_available
# Import the PyTorch helpers only when PyTorch is actually installed.
# is_torch_available is a function and must be called: the bare name is a
# function object, which is always truthy, so the un-called form made this
# import unconditional and raised ImportError on installs without PyTorch.
if is_torch_available():
    import torch
    from torch.utils.data import TensorDataset
if
is_tf_available
():
import
tensorflow
as
tf
...
...
@@ -73,7 +77,8 @@ def _is_whitespace(c):
return
False
def
squad_convert_examples_to_features
(
examples
,
tokenizer
,
max_seq_length
,
doc_stride
,
max_query_length
,
is_training
):
doc_stride
,
max_query_length
,
is_training
,
return_dataset
=
False
):
"""
Converts a list of examples into a list of features that can be directly given as input to a model.
It is model-dependent and takes advantage of many of the tokenizer's features to create the model's inputs.
...
...
@@ -84,7 +89,10 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
max_seq_length: The maximum sequence length of the inputs.
doc_stride: The stride used when the context is too large and is split across several features.
max_query_length: The maximum length of the query.
is_training: wheter to create features for model evaluation or model training.
is_training: whether to create features for model evaluation or model training.
return_dataset: Default False. Either 'pt' or 'tf'.
if 'pt': returns a torch.utils.data.TensorDataset,
if 'tf': returns a tf.data.Dataset
Returns:
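list of :class:`~transformers.data.processors.squad.SquadFeatures`;
if return_dataset is 'pt', a tuple (features, dataset) where dataset is a torch.utils.data.TensorDataset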
...
...
@@ -264,6 +272,31 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
unique_id
+=
1
if
return_dataset
==
'pt'
:
if
not
is_torch_available
():
raise
ImportError
(
"Pytorch must be installed to return a pytorch dataset."
)
# Convert to Tensors and build dataset
all_input_ids
=
torch
.
tensor
([
f
.
input_ids
for
f
in
features
],
dtype
=
torch
.
long
)
all_input_mask
=
torch
.
tensor
([
f
.
attention_mask
for
f
in
features
],
dtype
=
torch
.
long
)
all_segment_ids
=
torch
.
tensor
([
f
.
token_type_ids
for
f
in
features
],
dtype
=
torch
.
long
)
all_cls_index
=
torch
.
tensor
([
f
.
cls_index
for
f
in
features
],
dtype
=
torch
.
long
)
all_p_mask
=
torch
.
tensor
([
f
.
p_mask
for
f
in
features
],
dtype
=
torch
.
float
)
if
not
is_training
:
all_example_index
=
torch
.
arange
(
all_input_ids
.
size
(
0
),
dtype
=
torch
.
long
)
dataset
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_example_index
,
all_cls_index
,
all_p_mask
)
else
:
all_start_positions
=
torch
.
tensor
([
f
.
start_position
for
f
in
features
],
dtype
=
torch
.
long
)
all_end_positions
=
torch
.
tensor
([
f
.
end_position
for
f
in
features
],
dtype
=
torch
.
long
)
dataset
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_start_positions
,
all_end_positions
,
all_cls_index
,
all_p_mask
)
return
features
,
dataset
return
features
...
...
@@ -359,7 +392,7 @@ class SquadProcessor(DataProcessor):
if
self
.
dev_file
is
None
:
raise
ValueError
(
"SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor"
)
with
open
(
os
.
path
.
join
(
data_dir
,
self
.
dev_file
if
filename
is
not
None
else
filename
),
"r"
,
encoding
=
'utf-8'
)
as
reader
:
with
open
(
os
.
path
.
join
(
data_dir
,
self
.
dev_file
if
filename
is
None
else
filename
),
"r"
,
encoding
=
'utf-8'
)
as
reader
:
input_data
=
json
.
load
(
reader
)[
"data"
]
return
self
.
_create_examples
(
input_data
,
"dev"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment