"tests/utils/test_doc_samples.py" did not exist on "ff9e79ba3a3dd35c1a7edbd669cf78e082b2f7dc"
Commit be3b9bcf authored by Jade Abbott's avatar Jade Abbott
Browse files

Allow one to use the pretrained model in evaluation when do_train is not selected

parent 8da280eb
...@@ -431,7 +431,7 @@ def main(): ...@@ -431,7 +431,7 @@ def main():
if not args.do_train and not args.do_eval: if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.") raise ValueError("At least one of `do_train` or `do_eval` must be True.")
if os.path.exists(args.output_dir) and os.listdir(args.output_dir): if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
...@@ -554,6 +554,7 @@ def main(): ...@@ -554,6 +554,7 @@ def main():
# Save a trained model # Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
if args.do_train:
torch.save(model_to_save.state_dict(), output_model_file) torch.save(model_to_save.state_dict(), output_model_file)
# Load a trained model that you have fine-tuned # Load a trained model that you have fine-tuned
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment