"ts/nni_manager/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "e349b4409ddf19a1356e17d758999065d121dc3d"
Commit 3a9a9f78 authored by Rémi Louf's avatar Rémi Louf Committed by Julien Chaumond
Browse files

default output dir to documents dir

parent 693606a7
...@@ -31,9 +31,7 @@ Batch = namedtuple( ...@@ -31,9 +31,7 @@ Batch = namedtuple(
def evaluate(args): def evaluate(args):
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True) tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
model = bertabs = BertAbs.from_pretrained( model = bertabs = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
"bertabs-finetuned-{}".format(args.finetuned_model)
)
bertabs.to(args.device) bertabs.to(args.device)
bertabs.eval() bertabs.eval()
...@@ -195,8 +193,8 @@ def main(): ...@@ -195,8 +193,8 @@ def main():
"--summaries_output_dir", "--summaries_output_dir",
default=None, default=None,
type=str, type=str,
required=True, required=False,
help="The folder in wich the summaries should be written.", help="The folder in wich the summaries should be written. Defaults to the folder where the documents are",
) )
# EVALUATION options # EVALUATION options
parser.add_argument( parser.add_argument(
...@@ -242,6 +240,9 @@ def main(): ...@@ -242,6 +240,9 @@ def main():
args = parser.parse_args() args = parser.parse_args()
args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda") args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")
if not args.summaries_output_dir:
args.summaries_output_dir = args.documents_dir
if not documents_dir_is_valid(args.documents_dir): if not documents_dir_is_valid(args.documents_dir):
raise FileNotFoundError( raise FileNotFoundError(
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path." "We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
......
...@@ -39,6 +39,8 @@ class SummarizationDataset(Dataset): ...@@ -39,6 +39,8 @@ class SummarizationDataset(Dataset):
self.documents = [] self.documents = []
story_filenames_list = os.listdir(path) story_filenames_list = os.listdir(path)
for story_filename in story_filenames_list: for story_filename in story_filenames_list:
if "summary" in story_filename:
continue
path_to_story = os.path.join(path, story_filename) path_to_story = os.path.join(path, story_filename)
if not os.path.isfile(path_to_story): if not os.path.isfile(path_to_story):
continue continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment