Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
aa90e0c3
Commit
aa90e0c3
authored
Feb 01, 2019
by
joe dumoulin
Browse files
fix prediction on run-squad.py example
parent
8f8bbd4a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
4 deletions
+7
-4
examples/run_squad.py
examples/run_squad.py
+7
-4
No files found.
examples/run_squad.py
View file @
aa90e0c3
...
...
@@ -706,7 +706,7 @@ def main():
parser
.
add_argument
(
"--num_train_epochs"
,
default
=
3.0
,
type
=
float
,
help
=
"Total number of training epochs to perform."
)
parser
.
add_argument
(
"--warmup_proportion"
,
default
=
0.1
,
type
=
float
,
help
=
"Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%
%
"
help
=
"Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
"of training."
)
parser
.
add_argument
(
"--n_best_size"
,
default
=
20
,
type
=
int
,
help
=
"The total number of n-best predictions to generate in the nbest_predictions.json "
...
...
@@ -922,6 +922,9 @@ def main():
# Load a trained model that you have fine-tuned
model_state_dict
=
torch
.
load
(
output_model_file
)
model
=
BertForQuestionAnswering
.
from_pretrained
(
args
.
bert_model
,
state_dict
=
model_state_dict
)
else
:
model
=
BertForQuestionAnswering
.
from_pretrained
(
args
.
bert_model
)
model
.
to
(
device
)
if
args
.
do_predict
and
(
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment