"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "27c1b656cca75efa0cc414d3bf4e6aacf24829de"
Commit 934d3f4d authored by Matthew Carrigan

Syncing up argument names between the scripts

parent f19ba35b
@@ -201,8 +201,8 @@ def create_instances_from_document(
 def main():
     parser = ArgumentParser()
-    parser.add_argument('--corpus_path', type=Path, required=True)
-    parser.add_argument("--save_dir", type=Path, required=True)
+    parser.add_argument('--train_corpus', type=Path, required=True)
+    parser.add_argument("--output_dir", type=Path, required=True)
     parser.add_argument("--bert_model", type=str, required=True,
                         choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                  "bert-base-multilingual", "bert-base-chinese"])
@@ -229,7 +229,7 @@ def main():
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
-    with args.corpus_path.open() as f:
+    with args.train_corpus.open() as f:
         docs = []
         doc = []
         for line in tqdm(f, desc="Loading Dataset"):
@@ -241,7 +241,7 @@ def main():
                 tokens = tokenizer.tokenize(line)
                 doc.append(tokens)
-    args.save_dir.mkdir(exist_ok=True)
+    args.output_dir.mkdir(exist_ok=True)
     docs = DocumentDatabase(docs)
     # When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
     # Google BERT doesn't do this, and as a result oversamples shorter docs
@@ -256,8 +256,8 @@ def main():
             epoch_instances.extend(doc_instances)
         shuffle(epoch_instances)
-        epoch_file = args.save_dir / f"epoch_{epoch}.json"
-        metrics_file = args.save_dir / f"epoch_{epoch}_metrics.json"
+        epoch_file = args.output_dir / f"epoch_{epoch}.json"
+        metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
         with epoch_file.open('w') as out_file:
             for instance in epoch_instances:
                 out_file.write(instance + '\n')
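For readers following the rename, here is a minimal, self-contained sketch of the pregeneration script's renamed arguments in use. Only the argument names come from the diff above; the example values (`corpus.txt`, `out`) are hypothetical.

```python
# Minimal sketch of the renamed arguments (names from the diff above;
# 'corpus.txt' and 'out' are hypothetical example values).
from argparse import ArgumentParser
from pathlib import Path

parser = ArgumentParser()
parser.add_argument('--train_corpus', type=Path, required=True)  # formerly --corpus_path
parser.add_argument('--output_dir', type=Path, required=True)    # formerly --save_dir

args = parser.parse_args(['--train_corpus', 'corpus.txt', '--output_dir', 'out'])
args.output_dir.mkdir(exist_ok=True)  # mirrors the mkdir call changed above
print("Corpus:", args.train_corpus, "Output:", args.output_dir)
```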
...
@@ -401,7 +401,7 @@ def main():
     parser = argparse.ArgumentParser()
     ## Required parameters
-    parser.add_argument("--train_file",
+    parser.add_argument("--train_corpus",
                         default=None,
                         type=str,
                         required=True,
@@ -511,8 +511,8 @@ def main():
     #train_examples = None
     num_train_optimization_steps = None
     if args.do_train:
-        print("Loading Train Dataset", args.train_file)
-        train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
+        print("Loading Train Dataset", args.train_corpus)
+        train_dataset = BERTDataset(args.train_corpus, tokenizer, seq_len=args.max_seq_length,
                                     corpus_lines=None, on_memory=args.on_memory)
         num_train_optimization_steps = int(
             len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
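Likewise, a hedged sketch of how the training script's renamed flag would be parsed. `BERTDataset` and the surrounding training loop are omitted, and the example value is hypothetical; only `--train_corpus` (formerly `--train_file`) comes from the diff.

```python
# Minimal sketch: the training script's renamed flag (was --train_file).
# The help text and example value are assumptions for illustration.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--train_corpus", default=None, type=str, required=True)
args = parser.parse_args(["--train_corpus", "corpus.txt"])
print("Loading Train Dataset", args.train_corpus)
```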
...