"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "00337e9687abd7bd3ff55609dc87f92fc4443aac"
Commit aa90e0c3 authored by joe dumoulin's avatar joe dumoulin
Browse files

fix prediction on run-squad.py example

parent 8f8bbd4a
...@@ -706,7 +706,7 @@ def main(): ...@@ -706,7 +706,7 @@ def main():
parser.add_argument("--num_train_epochs", default=3.0, type=float, parser.add_argument("--num_train_epochs", default=3.0, type=float,
help="Total number of training epochs to perform.") help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion", default=0.1, type=float, parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% " help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
"of training.") "of training.")
parser.add_argument("--n_best_size", default=20, type=int, parser.add_argument("--n_best_size", default=20, type=int,
help="The total number of n-best predictions to generate in the nbest_predictions.json " help="The total number of n-best predictions to generate in the nbest_predictions.json "
...@@ -919,9 +919,12 @@ def main(): ...@@ -919,9 +919,12 @@ def main():
if args.do_train: if args.do_train:
torch.save(model_to_save.state_dict(), output_model_file) torch.save(model_to_save.state_dict(), output_model_file)
# Load a trained model that you have fine-tuned # Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file) model_state_dict = torch.load(output_model_file)
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict) model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
else:
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
model.to(device) model.to(device)
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0): if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment