"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8f8bbd4a4c7db98561048c2c0dd7501c57abda96"
Commit 0ae59e66 authored by Matthew Carrigan

Greatly reduced memory usage when pregenerating the data by writing it
out on the fly without shuffling; the Sampler in the finetuning script
will shuffle for us.
parent 6a9038ba
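
The shuffle can be dropped at generation time because shuffling is deferred to load time: a random sampler draws dataset indices in random order, so the on-disk order of the epoch files never reaches the model. A minimal sketch of that idea (illustration only, not part of this commit; a toy TensorDataset stands in for PregeneratedDataset):

    import torch
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    dataset = TensorDataset(torch.arange(10))  # stand-in for PregeneratedDataset
    loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=4)
    for (batch,) in loader:
        print(batch)  # batches arrive in random order; no shuffle needed on disk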
@@ -73,7 +73,10 @@ class PregeneratedDataset(Dataset):
         logging.info(f"Loading training examples for epoch {epoch}")
         with data_file.open() as f:
             for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")):
-                example = json.loads(line.rstrip())
+                line = line.strip()
+                if not line:
+                    continue  # Skip trailing blank lines etc.
+                example = json.loads(line)
                 features = convert_example_to_features(example, tokenizer, seq_len)
                 input_ids[i] = features.input_ids
                 segment_ids[i] = features.segment_ids
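
Why the new guard matters: json.loads raises a JSONDecodeError on empty input, so a trailing newline at the end of an epoch file would otherwise crash the loader. A minimal sketch (illustration only, not part of the diff):

    import json

    for raw in ['{"tokens": ["[CLS]"]}', "", "\n"]:
        line = raw.strip()
        if not line:
            print("skipped blank line")  # mirrors the guard added above
            continue
        print(json.loads(line))  # parses only non-blank lines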
@@ -242,24 +242,22 @@ def main():
     # When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
     # Google BERT doesn't do this, and as a result oversamples shorter docs
     for epoch in trange(args.epochs_to_generate, desc="Epoch"):
-        epoch_instances = []
-        for doc_idx in trange(len(docs), desc="Document"):
-            doc_instances = create_instances_from_document(
-                docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
-                masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
-                vocab_list=vocab_list)
-            doc_instances = [json.dumps(instance) for instance in doc_instances]
-            epoch_instances.extend(doc_instances)
-
-        shuffle(epoch_instances)
-        epoch_file = args.output_dir / f"epoch_{epoch}.json"
+        epoch_filename = args.output_dir / f"epoch_{epoch}.json"
+        num_instances = 0
+        with epoch_filename.open('w') as epoch_file:
+            for doc_idx in trange(len(docs), desc="Document"):
+                doc_instances = create_instances_from_document(
+                    docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
+                    masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
+                    vocab_list=vocab_list)
+                doc_instances = [json.dumps(instance) for instance in doc_instances]
+                for instance in doc_instances:
+                    epoch_file.write(instance + '\n')
+                    num_instances += 1
         metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
-        with epoch_file.open('w') as out_file:
-            for instance in epoch_instances:
-                out_file.write(instance + '\n')
         with metrics_file.open('w') as metrics_file:
             metrics = {
-                "num_training_examples": len(epoch_instances),
+                "num_training_examples": num_instances,
                 "max_seq_len": args.max_seq_len
             }
             metrics_file.write(json.dumps(metrics))
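
The rewrite above bounds peak memory at one document's worth of instances instead of a whole epoch's. A standalone sketch of the same streaming-write pattern (my illustration; write_epoch and the toy generator are hypothetical, not from the repo):

    import json
    from pathlib import Path

    def write_epoch(instances, path):
        # Stream JSON lines to disk; peak memory is one instance, not the epoch.
        num_instances = 0
        with path.open('w') as epoch_file:
            for instance in instances:  # works with a generator; nothing accumulates
                epoch_file.write(json.dumps(instance) + '\n')
                num_instances += 1
        return num_instances

    print(write_epoch(({"id": i} for i in range(3)), Path("epoch_0.json")))  # 3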