Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7296f101
Commit
7296f101
authored
Dec 12, 2019
by
LysandreJik
Browse files
Cleanup squad and add allow train_file and predict_file usage
parent
5d67aa21
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
8 deletions
+20
-8
examples/run_squad.py
examples/run_squad.py
+14
-8
transformers/data/processors/squad.py
transformers/data/processors/squad.py
+6
-0
No files found.
examples/run_squad.py
View file @
7296f101
...
...
@@ -337,7 +337,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
else
:
logger
.
info
(
"Creating features from dataset file at %s"
,
input_dir
)
if
not
args
.
data_dir
:
if
not
args
.
data_dir
and
((
evaluate
and
not
args
.
predict_file
)
or
(
not
evaluate
and
not
args
.
train_file
))
:
try
:
import
tensorflow_datasets
as
tfds
except
ImportError
:
...
...
@@ -350,7 +350,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
examples
=
SquadV1Processor
().
get_examples_from_dataset
(
tfds_examples
,
evaluate
=
evaluate
)
else
:
processor
=
SquadV2Processor
()
if
args
.
version_2_with_negative
else
SquadV1Processor
()
examples
=
processor
.
get_dev_examples
(
args
.
data_dir
)
if
evaluate
else
processor
.
get_train_examples
(
args
.
data_dir
)
if
evaluate
:
examples
=
processor
.
get_dev_examples
(
args
.
data_dir
,
filename
=
args
.
predict_file
)
else
:
examples
=
processor
.
get_train_examples
(
args
.
data_dir
,
filename
=
args
.
train_file
)
features
,
dataset
=
squad_convert_examples_to_features
(
examples
=
examples
,
...
...
@@ -387,7 +391,14 @@ def main():
## Other parameters
parser
.
add_argument
(
"--data_dir"
,
default
=
None
,
type
=
str
,
help
=
"The input data dir. Should contain the .json files for the task. If not specified, will run with tensorflow_datasets."
)
help
=
"The input data dir. Should contain the .json files for the task. "
+
"If no data dir or train/predict files are specified, will run with tensorflow_datasets."
)
parser
.
add_argument
(
"--train_file"
,
default
=
None
,
type
=
str
,
help
=
"The input training file. If a data dir is specified, will look for the file there. "
+
"If no data dir or train/predict files are specified, will run with tensorflow_datasets."
)
parser
.
add_argument
(
"--predict_file"
,
default
=
None
,
type
=
str
,
help
=
"The input evaluation file. If a data dir is specified, will look for the file there. "
+
"If no data dir or train/predict files are specified, will run with tensorflow_datasets."
)
parser
.
add_argument
(
"--config_name"
,
default
=
""
,
type
=
str
,
help
=
"Pretrained config name or path if not the same as model_name"
)
parser
.
add_argument
(
"--tokenizer_name"
,
default
=
""
,
type
=
str
,
...
...
@@ -472,11 +483,6 @@ def main():
parser
.
add_argument
(
'--server_port'
,
type
=
str
,
default
=
''
,
help
=
"Can be used for distant debugging."
)
args
=
parser
.
parse_args
()
args
.
predict_file
=
os
.
path
.
join
(
args
.
output_dir
,
'predictions_{}_{}.txt'
.
format
(
list
(
filter
(
None
,
args
.
model_name_or_path
.
split
(
'/'
))).
pop
(),
str
(
args
.
max_seq_length
))
)
if
os
.
path
.
exists
(
args
.
output_dir
)
and
os
.
listdir
(
args
.
output_dir
)
and
args
.
do_train
and
not
args
.
overwrite_output_dir
:
raise
ValueError
(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
.
format
(
args
.
output_dir
))
...
...
transformers/data/processors/squad.py
View file @
7296f101
...
...
@@ -373,6 +373,9 @@ class SquadProcessor(DataProcessor):
which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.
"""
if
data_dir
is
None
:
data_dir
=
""
if
self
.
train_file
is
None
:
raise
ValueError
(
"SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor"
)
...
...
@@ -389,6 +392,9 @@ class SquadProcessor(DataProcessor):
filename: None by default, specify this if the evaluation file has a different name than the original one
which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and 2.0 respectively.
"""
if
data_dir
is
None
:
data_dir
=
""
if
self
.
dev_file
is
None
:
raise
ValueError
(
"SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment