Commit d3fcec1a authored by thomwolf's avatar thomwolf
Browse files

Add saving and loading of the model in the examples

parent 93f335ef
...@@ -546,6 +546,15 @@ def main(): ...@@ -546,6 +546,15 @@ def main():
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
# Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
torch.save(model_to_save.state_dict(), output_model_file)
# Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file)
model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict)
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = processor.get_dev_examples(args.data_dir) eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features( eval_features = convert_examples_to_features(
...@@ -593,10 +602,6 @@ def main(): ...@@ -593,10 +602,6 @@ def main():
'global_step': global_step, 'global_step': global_step,
'loss': tr_loss/nb_tr_steps} 'loss': tr_loss/nb_tr_steps}
model_to_save = model.module if hasattr(model, 'module') else model
raise NotImplementedError # TODO add save of the configuration file and vocabulary file also ?
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
torch.save(model_to_save, output_model_file)
output_eval_file = os.path.join(args.output_dir, "eval_results.txt") output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer: with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****") logger.info("***** Eval results *****")
......
...@@ -911,6 +911,15 @@ def main(): ...@@ -911,6 +911,15 @@ def main():
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
# Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
torch.save(model_to_save.state_dict(), output_model_file)
# Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file)
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0): if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = read_squad_examples( eval_examples = read_squad_examples(
input_file=args.predict_file, is_training=False) input_file=args.predict_file, is_training=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment