Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ce158a07
Commit
ce158a07
authored
Dec 04, 2019
by
LysandreJik
Browse files
Return dataset (pytorch)
parent
7a035199
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
4 deletions
+37
-4
transformers/data/processors/squad.py
transformers/data/processors/squad.py
+37
-4
No files found.
transformers/data/processors/squad.py
View file @
ce158a07
...
...
@@ -7,7 +7,11 @@ import numpy as np
from
...tokenization_bert
import
BasicTokenizer
,
whitespace_tokenize
from
.utils
import
DataProcessor
,
InputExample
,
InputFeatures
from
...file_utils
import
is_tf_available
from
...file_utils
import
is_tf_available
,
is_torch_available
if
is_torch_available
:
import
torch
from
torch.utils.data
import
TensorDataset
if
is_tf_available
():
import
tensorflow
as
tf
...
...
@@ -73,7 +77,8 @@ def _is_whitespace(c):
return
False
def
squad_convert_examples_to_features
(
examples
,
tokenizer
,
max_seq_length
,
doc_stride
,
max_query_length
,
is_training
):
doc_stride
,
max_query_length
,
is_training
,
return_dataset
=
False
):
"""
Converts a list of examples into a list of features that can be directly given as input to a model.
It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
...
...
@@ -84,7 +89,10 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
max_seq_length: The maximum sequence length of the inputs.
doc_stride: The stride used when the context is too large and is split across several features.
max_query_length: The maximum length of the query.
is_training: wheter to create features for model evaluation or model training.
is_training: whether to create features for model evaluation or model training.
return_dataset: Default False. Either 'pt' or 'tf'.
if 'pt': returns a torch.data.TensorDataset,
if 'tf': returns a tf.data.Dataset
Returns:
list of :class:`~transformers.data.processors.squad.SquadFeatures`
...
...
@@ -264,6 +272,31 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
unique_id
+=
1
if
return_dataset
==
'pt'
:
if
not
is_torch_available
():
raise
ImportError
(
"Pytorch must be installed to return a pytorch dataset."
)
# Convert to Tensors and build dataset
all_input_ids
=
torch
.
tensor
([
f
.
input_ids
for
f
in
features
],
dtype
=
torch
.
long
)
all_input_mask
=
torch
.
tensor
([
f
.
attention_mask
for
f
in
features
],
dtype
=
torch
.
long
)
all_segment_ids
=
torch
.
tensor
([
f
.
token_type_ids
for
f
in
features
],
dtype
=
torch
.
long
)
all_cls_index
=
torch
.
tensor
([
f
.
cls_index
for
f
in
features
],
dtype
=
torch
.
long
)
all_p_mask
=
torch
.
tensor
([
f
.
p_mask
for
f
in
features
],
dtype
=
torch
.
float
)
if
not
is_training
:
all_example_index
=
torch
.
arange
(
all_input_ids
.
size
(
0
),
dtype
=
torch
.
long
)
dataset
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_example_index
,
all_cls_index
,
all_p_mask
)
else
:
all_start_positions
=
torch
.
tensor
([
f
.
start_position
for
f
in
features
],
dtype
=
torch
.
long
)
all_end_positions
=
torch
.
tensor
([
f
.
end_position
for
f
in
features
],
dtype
=
torch
.
long
)
dataset
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_start_positions
,
all_end_positions
,
all_cls_index
,
all_p_mask
)
return
features
,
dataset
return
features
...
...
@@ -359,7 +392,7 @@ class SquadProcessor(DataProcessor):
if
self
.
dev_file
is
None
:
raise
ValueError
(
"SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor"
)
with
open
(
os
.
path
.
join
(
data_dir
,
self
.
dev_file
if
filename
is
not
None
else
filename
),
"r"
,
encoding
=
'utf-8'
)
as
reader
:
with
open
(
os
.
path
.
join
(
data_dir
,
self
.
dev_file
if
filename
is
None
else
filename
),
"r"
,
encoding
=
'utf-8'
)
as
reader
:
input_data
=
json
.
load
(
reader
)[
"data"
]
return
self
.
_create_examples
(
input_data
,
"dev"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment