Commit 934d3f4d authored by Matthew Carrigan

Syncing up argument names between the scripts
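
For context, a minimal sketch of what the renamed arguments look like in the data-generation script after this change (reduced from the first hunk below; the script's other arguments are unchanged and omitted). The fine-tuning script's --train_file is likewise renamed to --train_corpus, as shown in the second file's hunks.

from argparse import ArgumentParser
from pathlib import Path

# Sketch only. Renamed in this commit: --corpus_path -> --train_corpus and
# --save_dir -> --output_dir, so the corpus argument matches the --train_corpus
# name used by the fine-tuning script.
parser = ArgumentParser()
parser.add_argument('--train_corpus', type=Path, required=True)  # previously --corpus_path
parser.add_argument('--output_dir', type=Path, required=True)    # previously --save_dir
args = parser.parse_args()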

parent f19ba35b
@@ -201,8 +201,8 @@ def create_instances_from_document(
 def main():
     parser = ArgumentParser()
-    parser.add_argument('--corpus_path', type=Path, required=True)
-    parser.add_argument("--save_dir", type=Path, required=True)
+    parser.add_argument('--train_corpus', type=Path, required=True)
+    parser.add_argument("--output_dir", type=Path, required=True)
     parser.add_argument("--bert_model", type=str, required=True,
                         choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                  "bert-base-multilingual", "bert-base-chinese"])
@@ -229,7 +229,7 @@ def main():
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
-    with args.corpus_path.open() as f:
+    with args.train_corpus.open() as f:
         docs = []
         doc = []
         for line in tqdm(f, desc="Loading Dataset"):
@@ -241,7 +241,7 @@ def main():
                 tokens = tokenizer.tokenize(line)
                 doc.append(tokens)
-    args.save_dir.mkdir(exist_ok=True)
+    args.output_dir.mkdir(exist_ok=True)
     docs = DocumentDatabase(docs)
     # When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
     # Google BERT doesn't do this, and as a result oversamples shorter docs
@@ -256,8 +256,8 @@ def main():
             epoch_instances.extend(doc_instances)
         shuffle(epoch_instances)
-        epoch_file = args.save_dir / f"epoch_{epoch}.json"
-        metrics_file = args.save_dir / f"epoch_{epoch}_metrics.json"
+        epoch_file = args.output_dir / f"epoch_{epoch}.json"
+        metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
         with epoch_file.open('w') as out_file:
             for instance in epoch_instances:
                 out_file.write(instance + '\n')
@@ -401,7 +401,7 @@ def main():
     parser = argparse.ArgumentParser()

     ## Required parameters
-    parser.add_argument("--train_file",
+    parser.add_argument("--train_corpus",
                         default=None,
                         type=str,
                         required=True,
@@ -511,8 +511,8 @@ def main():
     #train_examples = None
     num_train_optimization_steps = None
     if args.do_train:
-        print("Loading Train Dataset", args.train_file)
-        train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
+        print("Loading Train Dataset", args.train_corpus)
+        train_dataset = BERTDataset(args.train_corpus, tokenizer, seq_len=args.max_seq_length,
                                     corpus_lines=None, on_memory=args.on_memory)
         num_train_optimization_steps = int(
             len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs