Unverified commit 766c6b2c authored by Thomas Wolf, committed by GitHub

Merge pull request #159 from jaderabbit/master

Allow do_eval to be used without do_train and to use the pretrained model in the output folder
parents 77966a43 193e2df8
@@ -432,7 +432,7 @@ def main():
     if not args.do_train and not args.do_eval:
         raise ValueError("At least one of `do_train` or `do_eval` must be True.")
-    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
+    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
     os.makedirs(args.output_dir, exist_ok=True)
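Previously, any existing non-empty output directory was rejected outright, so an eval-only run could not point at the folder that already holds a fine-tuned model. A minimal sketch of the relaxed check (the helper name check_output_dir and the /tmp path are made up for illustration, not part of the script):

```python
import os

def check_output_dir(output_dir, do_train, do_eval):
    # Hypothetical helper mirroring the relaxed check: an existing, non-empty
    # output directory is only an error when we are about to train into it.
    if not do_train and not do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")
    if os.path.exists(output_dir) and os.listdir(output_dir) and do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(output_dir))
    os.makedirs(output_dir, exist_ok=True)

# Eval-only reuse of an existing folder is now allowed; training into it is not.
check_output_dir("/tmp/bert_output", do_train=False, do_eval=True)
```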
@@ -504,6 +504,8 @@ def main():
                          t_total=t_total)
     global_step = 0
+    nb_tr_steps = 0
+    tr_loss = 0
     if args.do_train:
         train_features = convert_examples_to_features(
             train_examples, label_list, args.max_seq_length, tokenizer)
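Pre-initialising nb_tr_steps and tr_loss outside the do_train branch matters because the reporting code later divides tr_loss by nb_tr_steps; without defaults, an eval-only run would hit a NameError. A rough illustration of the failure mode being avoided (the do_train flag and the print are stand-ins, not part of the script):

```python
do_train = False  # e.g. the script was launched with --do_eval but not --do_train

global_step = 0
nb_tr_steps = 0   # without these two defaults, the reporting code below would
tr_loss = 0       # raise NameError in an eval-only run

if do_train:
    # the real training loop accumulates tr_loss and nb_tr_steps here
    pass

# Later reporting can now reference the counters safely in both modes.
loss = tr_loss / nb_tr_steps if do_train else None
print(loss)  # -> None when only evaluating
```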
@@ -555,7 +557,8 @@ def main():
     # Save a trained model
     model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
     output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
-    torch.save(model_to_save.state_dict(), output_model_file)
+    if args.do_train:
+        torch.save(model_to_save.state_dict(), output_model_file)
     # Load a trained model that you have fine-tuned
     model_state_dict = torch.load(output_model_file)
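With the guard, pytorch_model.bin is only written after an actual training pass, while the subsequent load still reads from output_dir; that is what lets a --do_eval-only invocation pick up a model fine-tuned in an earlier run. A self-contained sketch of the save/load pattern, using a toy nn.Linear and a SimpleNamespace in place of the real model and argparse args (both hypothetical):

```python
import os
import torch
from torch import nn
from types import SimpleNamespace

# Stand-ins for the script's argparse `args` and BERT `model` (hypothetical values).
args = SimpleNamespace(output_dir="/tmp/bert_output", do_train=True)
model = nn.Linear(4, 2)

os.makedirs(args.output_dir, exist_ok=True)
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")

if args.do_train:
    # Unwrap DataParallel if present, then persist the fine-tuned weights.
    model_to_save = model.module if hasattr(model, 'module') else model
    torch.save(model_to_save.state_dict(), output_model_file)

# Both modes reload from output_dir, so a later eval-only run (do_train=False)
# finds the weights saved by this training run.
model_state_dict = torch.load(output_model_file)
```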
@@ -581,7 +584,8 @@ def main():
         model.eval()
         eval_loss, eval_accuracy = 0, 0
         nb_eval_steps, nb_eval_examples = 0, 0
-        for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
+        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
             input_ids = input_ids.to(device)
             input_mask = input_mask.to(device)
             segment_ids = segment_ids.to(device)
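The change inside the evaluation loop itself is cosmetic: wrapping eval_dataloader in tqdm adds an "Evaluating" progress bar without touching the loop body. A toy reproduction with dummy tensors (shapes and batch size are arbitrary placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dummy tensors standing in for the featurised eval set.
eval_data = TensorDataset(torch.zeros(8, 16, dtype=torch.long),   # input_ids
                          torch.ones(8, 16, dtype=torch.long),    # input_mask
                          torch.zeros(8, 16, dtype=torch.long),   # segment_ids
                          torch.zeros(8, dtype=torch.long))       # label_ids
eval_dataloader = DataLoader(eval_data, batch_size=4)

# tqdm only wraps the iterator; the per-batch work is unchanged.
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    label_ids = label_ids.to(device)
```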
@@ -603,11 +607,11 @@ def main():
         eval_loss = eval_loss / nb_eval_steps
         eval_accuracy = eval_accuracy / nb_eval_examples
+        loss = tr_loss/nb_tr_steps if args.do_train else None
         result = {'eval_loss': eval_loss,
                   'eval_accuracy': eval_accuracy,
                   'global_step': global_step,
-                  'loss': tr_loss/nb_tr_steps}
+                  'loss': loss}
         output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
         with open(output_eval_file, "w") as writer:
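Guarding the division means an eval-only run records None for the training loss instead of dividing by a zero nb_tr_steps. A small sketch of the reporting step; the metric values and /tmp path are made-up placeholders:

```python
import os

# Hypothetical values standing in for the script's accumulators.
do_train = False
tr_loss, nb_tr_steps = 0, 0
eval_loss, eval_accuracy, global_step = 0.35, 0.87, 0

# With do_train=False this stays None rather than raising ZeroDivisionError.
loss = tr_loss / nb_tr_steps if do_train else None

result = {'eval_loss': eval_loss,
          'eval_accuracy': eval_accuracy,
          'global_step': global_step,
          'loss': loss}

output_dir = "/tmp/bert_output"
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, "eval_results.txt"), "w") as writer:
    for key in sorted(result.keys()):
        writer.write("%s = %s\n" % (key, str(result[key])))
```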