Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7296f101
Commit
7296f101
authored
Dec 12, 2019
by
LysandreJik
Browse files
Cleanup squad and add allow train_file and predict_file usage
parent
5d67aa21
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
8 deletions
+20
-8
examples/run_squad.py
examples/run_squad.py
+14
-8
transformers/data/processors/squad.py
transformers/data/processors/squad.py
+6
-0
No files found.
examples/run_squad.py
View file @
7296f101
...
@@ -337,7 +337,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
...
@@ -337,7 +337,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
else
:
else
:
logger
.
info
(
"Creating features from dataset file at %s"
,
input_dir
)
logger
.
info
(
"Creating features from dataset file at %s"
,
input_dir
)
if
not
args
.
data_dir
:
if
not
args
.
data_dir
and
((
evaluate
and
not
args
.
predict_file
)
or
(
not
evaluate
and
not
args
.
train_file
))
:
try
:
try
:
import
tensorflow_datasets
as
tfds
import
tensorflow_datasets
as
tfds
except
ImportError
:
except
ImportError
:
...
@@ -350,7 +350,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
...
@@ -350,7 +350,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
examples
=
SquadV1Processor
().
get_examples_from_dataset
(
tfds_examples
,
evaluate
=
evaluate
)
examples
=
SquadV1Processor
().
get_examples_from_dataset
(
tfds_examples
,
evaluate
=
evaluate
)
else
:
else
:
processor
=
SquadV2Processor
()
if
args
.
version_2_with_negative
else
SquadV1Processor
()
processor
=
SquadV2Processor
()
if
args
.
version_2_with_negative
else
SquadV1Processor
()
examples
=
processor
.
get_dev_examples
(
args
.
data_dir
)
if
evaluate
else
processor
.
get_train_examples
(
args
.
data_dir
)
if
evaluate
:
examples
=
processor
.
get_dev_examples
(
args
.
data_dir
,
filename
=
args
.
predict_file
)
else
:
examples
=
processor
.
get_train_examples
(
args
.
data_dir
,
filename
=
args
.
train_file
)
features
,
dataset
=
squad_convert_examples_to_features
(
features
,
dataset
=
squad_convert_examples_to_features
(
examples
=
examples
,
examples
=
examples
,
...
@@ -387,7 +391,14 @@ def main():
...
@@ -387,7 +391,14 @@ def main():
## Other parameters
## Other parameters
parser
.
add_argument
(
"--data_dir"
,
default
=
None
,
type
=
str
,
parser
.
add_argument
(
"--data_dir"
,
default
=
None
,
type
=
str
,
help
=
"The input data dir. Should contain the .json files for the task. If not specified, will run with tensorflow_datasets."
)
help
=
"The input data dir. Should contain the .json files for the task."
+
"If no data dir or train/predict files are specified, will run with tensorflow_datasets."
)
parser
.
add_argument
(
"--train_file"
,
default
=
None
,
type
=
str
,
help
=
"The input training file. If a data dir is specified, will look for the file there"
+
"If no data dir or train/predict files are specified, will run with tensorflow_datasets."
)
parser
.
add_argument
(
"--predict_file"
,
default
=
None
,
type
=
str
,
help
=
"The input evaluation file. If a data dir is specified, will look for the file there"
+
"If no data dir or train/predict files are specified, will run with tensorflow_datasets."
)
parser
.
add_argument
(
"--config_name"
,
default
=
""
,
type
=
str
,
parser
.
add_argument
(
"--config_name"
,
default
=
""
,
type
=
str
,
help
=
"Pretrained config name or path if not the same as model_name"
)
help
=
"Pretrained config name or path if not the same as model_name"
)
parser
.
add_argument
(
"--tokenizer_name"
,
default
=
""
,
type
=
str
,
parser
.
add_argument
(
"--tokenizer_name"
,
default
=
""
,
type
=
str
,
...
@@ -472,11 +483,6 @@ def main():
...
@@ -472,11 +483,6 @@ def main():
parser
.
add_argument
(
'--server_port'
,
type
=
str
,
default
=
''
,
help
=
"Can be used for distant debugging."
)
parser
.
add_argument
(
'--server_port'
,
type
=
str
,
default
=
''
,
help
=
"Can be used for distant debugging."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
args
.
predict_file
=
os
.
path
.
join
(
args
.
output_dir
,
'predictions_{}_{}.txt'
.
format
(
list
(
filter
(
None
,
args
.
model_name_or_path
.
split
(
'/'
))).
pop
(),
str
(
args
.
max_seq_length
))
)
if
os
.
path
.
exists
(
args
.
output_dir
)
and
os
.
listdir
(
args
.
output_dir
)
and
args
.
do_train
and
not
args
.
overwrite_output_dir
:
if
os
.
path
.
exists
(
args
.
output_dir
)
and
os
.
listdir
(
args
.
output_dir
)
and
args
.
do_train
and
not
args
.
overwrite_output_dir
:
raise
ValueError
(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
.
format
(
args
.
output_dir
))
raise
ValueError
(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
.
format
(
args
.
output_dir
))
...
...
transformers/data/processors/squad.py
View file @
7296f101
...
@@ -373,6 +373,9 @@ class SquadProcessor(DataProcessor):
...
@@ -373,6 +373,9 @@ class SquadProcessor(DataProcessor):
which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.
which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.
"""
"""
if
data_dir
is
None
:
data_dir
=
""
if
self
.
train_file
is
None
:
if
self
.
train_file
is
None
:
raise
ValueError
(
"SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor"
)
raise
ValueError
(
"SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor"
)
...
@@ -389,6 +392,9 @@ class SquadProcessor(DataProcessor):
...
@@ -389,6 +392,9 @@ class SquadProcessor(DataProcessor):
filename: None by default, specify this if the evaluation file has a different name than the original one
filename: None by default, specify this if the evaluation file has a different name than the original one
which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and 2.0 respectively.
which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and 2.0 respectively.
"""
"""
if
data_dir
is
None
:
data_dir
=
""
if
self
.
dev_file
is
None
:
if
self
.
dev_file
is
None
:
raise
ValueError
(
"SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor"
)
raise
ValueError
(
"SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment