Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
fd338abd
Commit
fd338abd
authored
Apr 06, 2021
by
Sylvain Gugger
Browse files
Style
parent
aef4cf8c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
19 deletions
+8
-19
examples/question-answering/run_qa_beam_search_no_trainer.py
examples/question-answering/run_qa_beam_search_no_trainer.py
+5
-12
examples/question-answering/run_qa_no_trainer.py
examples/question-answering/run_qa_no_trainer.py
+3
-7
No files found.
examples/question-answering/run_qa_beam_search_no_trainer.py
View file @
fd338abd
...
...
@@ -76,9 +76,7 @@ def parse_args():
parser
.
add_argument
(
"--preprocessing_num_workers"
,
type
=
int
,
default
=
4
,
help
=
"A csv or a json file containing the training data."
)
parser
.
add_argument
(
"--do_predict"
,
action
=
"store_true"
,
help
=
"Eval the question answering model"
)
parser
.
add_argument
(
"--do_predict"
,
action
=
"store_true"
,
help
=
"Eval the question answering model"
)
parser
.
add_argument
(
"--validation_file"
,
type
=
str
,
default
=
None
,
help
=
"A csv or a json file containing the validation data."
)
...
...
@@ -284,7 +282,7 @@ def main():
# Preprocessing the datasets.
# Preprocessing is slightly different for training and evaluation.
column_names
=
raw_datasets
[
"train"
].
column_names
question_column_name
=
"question"
if
"question"
in
column_names
else
column_names
[
0
]
context_column_name
=
"context"
if
"context"
in
column_names
else
column_names
[
1
]
answer_column_name
=
"answers"
if
"answers"
in
column_names
else
column_names
[
2
]
...
...
@@ -396,7 +394,6 @@ def main():
return
tokenized_examples
if
"train"
not
in
raw_datasets
:
raise
ValueError
(
"--do_train requires a train dataset"
)
train_dataset
=
raw_datasets
[
"train"
]
...
...
@@ -481,7 +478,6 @@ def main():
return
tokenized_examples
if
"validation"
not
in
raw_datasets
:
raise
ValueError
(
"--do_eval requires a validation dataset"
)
eval_examples
=
raw_datasets
[
"validation"
]
...
...
@@ -539,11 +535,8 @@ def main():
train_dataset
,
shuffle
=
True
,
collate_fn
=
data_collator
,
batch_size
=
args
.
per_device_train_batch_size
)
eval_dataset
.
set_format
(
type
=
"torch"
,
columns
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
])
eval_dataloader
=
DataLoader
(
eval_dataset
,
collate_fn
=
data_collator
,
batch_size
=
args
.
per_device_eval_batch_size
)
eval_dataloader
=
DataLoader
(
eval_dataset
,
collate_fn
=
data_collator
,
batch_size
=
args
.
per_device_eval_batch_size
)
if
args
.
do_predict
:
test_dataset
.
set_format
(
type
=
"torch"
,
columns
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
])
...
...
@@ -605,8 +598,8 @@ def main():
if
step
+
batch_size
<
len
(
dataset
):
logits_concat
[
step
:
step
+
batch_size
,
:
cols
]
=
output_logit
else
:
logits_concat
[
step
:,
:
cols
]
=
output_logit
[:
len
(
dataset
)
-
step
]
logits_concat
[
step
:,
:
cols
]
=
output_logit
[:
len
(
dataset
)
-
step
]
step
+=
batch_size
return
logits_concat
...
...
examples/question-answering/run_qa_no_trainer.py
View file @
fd338abd
...
...
@@ -81,9 +81,7 @@ def parse_args():
parser
.
add_argument
(
"--preprocessing_num_workers"
,
type
=
int
,
default
=
4
,
help
=
"A csv or a json file containing the training data."
)
parser
.
add_argument
(
"--do_predict"
,
action
=
"store_true"
,
help
=
"Eval the question answering model"
)
parser
.
add_argument
(
"--do_predict"
,
action
=
"store_true"
,
help
=
"Eval the question answering model"
)
parser
.
add_argument
(
"--validation_file"
,
type
=
str
,
default
=
None
,
help
=
"A csv or a json file containing the validation data."
)
...
...
@@ -543,9 +541,7 @@ def main():
)
eval_dataset
.
set_format
(
type
=
"torch"
,
columns
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
])
eval_dataloader
=
DataLoader
(
eval_dataset
,
collate_fn
=
data_collator
,
batch_size
=
args
.
per_device_eval_batch_size
)
eval_dataloader
=
DataLoader
(
eval_dataset
,
collate_fn
=
data_collator
,
batch_size
=
args
.
per_device_eval_batch_size
)
if
args
.
do_predict
:
test_dataset
.
set_format
(
type
=
"torch"
,
columns
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
])
...
...
@@ -607,7 +603,7 @@ def main():
if
step
+
batch_size
<
len
(
dataset
):
logits_concat
[
step
:
step
+
batch_size
,
:
cols
]
=
output_logit
else
:
logits_concat
[
step
:,
:
cols
]
=
output_logit
[:
len
(
dataset
)
-
step
]
logits_concat
[
step
:,
:
cols
]
=
output_logit
[:
len
(
dataset
)
-
step
]
step
+=
batch_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment