"vscode:/vscode.git/clone" did not exist on "df536438073178da2940d6a36bdd9360fb7f4fc3"
Commit b13abfa9 authored by thomwolf

add saving and loading model in examples

parent 270fa2f2
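
The change adds the same save/load pattern to both example scripts: serialize the fine-tuned weights with torch.save on the model's state_dict (unwrapping model.module when the model is wrapped for multi-GPU or distributed training), then rebuild the model through from_pretrained with the saved state_dict. Below is a minimal standalone sketch of that pattern, assuming pytorch_pretrained_bert is installed and using the stock bert-base-uncased weights and an illustrative output directory (neither is fixed by the commit):

import os
import torch
from pytorch_pretrained_bert import BertForSequenceClassification

output_dir = "/tmp/bert_finetuned"  # illustrative path, not taken from the commit
os.makedirs(output_dir, exist_ok=True)

# Build (and normally fine-tune) the model as in the example scripts
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

# Save only the underlying model, unwrapping DataParallel/DistributedDataParallel if present
model_to_save = model.module if hasattr(model, 'module') else model
output_model_file = os.path.join(output_dir, "pytorch_model.bin")
torch.save(model_to_save.state_dict(), output_model_file)

# Reload the fine-tuned weights into a freshly constructed model
model_state_dict = torch.load(output_model_file)
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", state_dict=model_state_dict)

Saving the state_dict rather than pickling the whole module keeps the checkpoint free of wrapper-specific key prefixes, and loading it back through from_pretrained rebuilds the architecture from the pretrained configuration before overriding the weights; this also supersedes the unfinished torch.save(model_to_save, ...) block that the diff below removes.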
@@ -487,8 +487,8 @@ def main():
             len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
 
     # Prepare model
-    model = BertForSequenceClassification.from_pretrained(args.bert_model,
-                cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
+    cache_dir = PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank)  # for distributed learning
+    model = BertForSequenceClassification.from_pretrained(args.bert_model, cache_dir=cache_dir)
     if args.fp16:
         model.half()
     model.to(device)
@@ -579,6 +579,15 @@ def main():
                 model.zero_grad()
                 global_step += 1
 
+    # Save a trained model
+    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
+    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
+    torch.save(model_to_save.state_dict(), output_model_file)
+
+    # Load a trained model that you have fine-tuned
+    model_state_dict = torch.load(output_model_file)
+    model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict)
+
     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = processor.get_dev_examples(args.data_dir)
         eval_features = convert_examples_to_features(
@@ -626,10 +635,6 @@ def main():
                   'global_step': global_step,
                   'loss': tr_loss/nb_tr_steps}
 
-        model_to_save = model.module if hasattr(model, 'module') else model
-        raise NotImplementedError  # TODO add save of the configuration file and vocabulary file also ?
-        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
-        torch.save(model_to_save, output_model_file)
         output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
         with open(output_eval_file, "w") as writer:
             logger.info("***** Eval results *****")
@@ -933,6 +933,15 @@ def main():
                 model.zero_grad()
                 global_step += 1
 
+    # Save a trained model
+    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
+    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
+    torch.save(model_to_save.state_dict(), output_model_file)
+
+    # Load a trained model that you have fine-tuned
+    model_state_dict = torch.load(output_model_file)
+    model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
+
     if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = read_squad_examples(
             input_file=args.predict_file, is_training=False)