Commit 3a9a9f78 authored by Rémi Louf's avatar Rémi Louf Committed by Julien Chaumond
Browse files

default output dir to documents dir

parent 693606a7
......@@ -31,9 +31,7 @@ Batch = namedtuple(
def evaluate(args):
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
model = bertabs = BertAbs.from_pretrained(
"bertabs-finetuned-{}".format(args.finetuned_model)
)
model = bertabs = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
bertabs.to(args.device)
bertabs.eval()
......@@ -195,8 +193,8 @@ def main():
"--summaries_output_dir",
default=None,
type=str,
required=True,
help="The folder in wich the summaries should be written.",
required=False,
help="The folder in wich the summaries should be written. Defaults to the folder where the documents are",
)
# EVALUATION options
parser.add_argument(
......@@ -242,6 +240,9 @@ def main():
args = parser.parse_args()
args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")
if not args.summaries_output_dir:
args.summaries_output_dir = args.documents_dir
if not documents_dir_is_valid(args.documents_dir):
raise FileNotFoundError(
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
......
......@@ -39,6 +39,8 @@ class SummarizationDataset(Dataset):
self.documents = []
story_filenames_list = os.listdir(path)
for story_filename in story_filenames_list:
if "summary" in story_filename:
continue
path_to_story = os.path.join(path, story_filename)
if not os.path.isfile(path_to_story):
continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment