Unverified Commit aff44f0c authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge branch 'master' into master

parents f4fc9c61 7e7e4753
# Configuration for the stale bot that marks and closes inactive issues
# (probot/stale schema — keys below match its documented options).
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 60
# Number of days of inactivity before a stale issue is closed
daysUntilClose: 7
# Issues with these labels will never be considered stale
exemptLabels:
  - pinned
  - security
# Label to use when marking an issue as stale
staleLabel: wontfix
# Comment to post when marking an issue as stale. Set to `false` to disable
# NOTE: the folded-scalar body below must stay indented more deeply than the
# key, otherwise the file is not valid YAML.
markComment: >
  This issue has been automatically marked as stale because it has not had
  recent activity. It will be closed if no further activity occurs. Thank you
  for your contributions.
# Comment to post when closing a stale issue. Set to `false` to disable
closeComment: false
\ No newline at end of file
...@@ -37,6 +37,7 @@ python3 simple_lm_finetuning.py ...@@ -37,6 +37,7 @@ python3 simple_lm_finetuning.py
--bert_model bert-base-uncased --bert_model bert-base-uncased
--do_lower_case --do_lower_case
--output_dir finetuned_lm/ --output_dir finetuned_lm/
--do_train
``` ```
### Pregenerating training data ### Pregenerating training data
...@@ -60,4 +61,4 @@ python3 finetune_on_pregenerated.py ...@@ -60,4 +61,4 @@ python3 finetune_on_pregenerated.py
--do_lower_case --do_lower_case
--output_dir finetuned_lm/ --output_dir finetuned_lm/
--epochs 3 --epochs 3
``` ```
\ No newline at end of file
...@@ -123,9 +123,8 @@ def main(): ...@@ -123,9 +123,8 @@ def main():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('--pregenerated_data', type=Path, required=True) parser.add_argument('--pregenerated_data', type=Path, required=True)
parser.add_argument('--output_dir', type=Path, required=True) parser.add_argument('--output_dir', type=Path, required=True)
parser.add_argument("--bert_model", type=str, required=True, parser.add_argument("--bert_model", type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, "
choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased", "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
"bert-base-multilingual", "bert-base-chinese"])
parser.add_argument("--do_lower_case", action="store_true") parser.add_argument("--do_lower_case", action="store_true")
parser.add_argument("--reduce_memory", action="store_true", parser.add_argument("--reduce_memory", action="store_true",
help="Store training data as on-disc memmaps to massively reduce memory usage") help="Store training data as on-disc memmaps to massively reduce memory usage")
......
...@@ -4,7 +4,7 @@ from tqdm import tqdm, trange ...@@ -4,7 +4,7 @@ from tqdm import tqdm, trange
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import shelve import shelve
from random import random, randint, shuffle, choice, sample from random import random, randrange, randint, shuffle, choice, sample
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.tokenization import BertTokenizer
import numpy as np import numpy as np
import json import json
...@@ -30,6 +30,8 @@ class DocumentDatabase: ...@@ -30,6 +30,8 @@ class DocumentDatabase:
self.reduce_memory = reduce_memory self.reduce_memory = reduce_memory
def add_document(self, document): def add_document(self, document):
if not document:
return
if self.reduce_memory: if self.reduce_memory:
current_idx = len(self.doc_lengths) current_idx = len(self.doc_lengths)
self.document_shelf[str(current_idx)] = document self.document_shelf[str(current_idx)] = document
...@@ -49,11 +51,11 @@ class DocumentDatabase: ...@@ -49,11 +51,11 @@ class DocumentDatabase:
self._precalculate_doc_weights() self._precalculate_doc_weights()
rand_start = self.doc_cumsum[current_idx] rand_start = self.doc_cumsum[current_idx]
rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx] rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
sentence_index = randint(rand_start, rand_end-1) % self.cumsum_max sentence_index = randrange(rand_start, rand_end) % self.cumsum_max
sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right') sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
else: else:
# If we don't use sentence weighting, then every doc has an equal chance to be chosen # If we don't use sentence weighting, then every doc has an equal chance to be chosen
sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1) sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths)
assert sampled_doc_index != current_idx assert sampled_doc_index != current_idx
if self.reduce_memory: if self.reduce_memory:
return self.document_shelf[str(sampled_doc_index)] return self.document_shelf[str(sampled_doc_index)]
...@@ -170,7 +172,7 @@ def create_instances_from_document( ...@@ -170,7 +172,7 @@ def create_instances_from_document(
# (first) sentence. # (first) sentence.
a_end = 1 a_end = 1
if len(current_chunk) >= 2: if len(current_chunk) >= 2:
a_end = randint(1, len(current_chunk) - 1) a_end = randrange(1, len(current_chunk))
tokens_a = [] tokens_a = []
for j in range(a_end): for j in range(a_end):
...@@ -186,7 +188,7 @@ def create_instances_from_document( ...@@ -186,7 +188,7 @@ def create_instances_from_document(
# Sample a random document, with longer docs being sampled more frequently # Sample a random document, with longer docs being sampled more frequently
random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True) random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True)
random_start = randint(0, len(random_document) - 1) random_start = randrange(0, len(random_document))
for j in range(random_start, len(random_document)): for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j]) tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length: if len(tokens_b) >= target_b_length:
...@@ -264,6 +266,14 @@ def main(): ...@@ -264,6 +266,14 @@ def main():
else: else:
tokens = tokenizer.tokenize(line) tokens = tokenizer.tokenize(line)
doc.append(tokens) doc.append(tokens)
if doc:
docs.add_document(doc) # If the last doc didn't end on a newline, make sure it still gets added
if len(docs) <= 1:
exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
"ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
"indicate breaks between documents in your input file. If your dataset does not contain multiple "
"documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
"sections or paragraphs.")
args.output_dir.mkdir(exist_ok=True) args.output_dir.mkdir(exist_ok=True)
for epoch in trange(args.epochs_to_generate, desc="Epoch"): for epoch in trange(args.epochs_to_generate, desc="Epoch"):
......
...@@ -95,7 +95,7 @@ class DataProcessor(object): ...@@ -95,7 +95,7 @@ class DataProcessor(object):
@classmethod @classmethod
def _read_tsv(cls, input_file, quotechar=None): def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file.""" """Reads a tab separated value file."""
with open(input_file, "r") as f: with open(input_file, "r", encoding="utf-8") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar) reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = [] lines = []
for line in reader: for line in reader:
......
...@@ -83,8 +83,9 @@ def run_model(): ...@@ -83,8 +83,9 @@ def run_model():
elif args.length > model.config.n_ctx: elif args.length > model.config.n_ctx:
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx) raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
if not args.unconditional: while True:
while True: context_tokens = []
if not args.unconditional:
raw_text = input("Model prompt >>> ") raw_text = input("Model prompt >>> ")
while not raw_text: while not raw_text:
print('Prompt should not be empty!') print('Prompt should not be empty!')
...@@ -123,6 +124,8 @@ def run_model(): ...@@ -123,6 +124,8 @@ def run_model():
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text) print(text)
print("=" * 80) print("=" * 80)
if args.unconditional:
break
if __name__ == '__main__': if __name__ == '__main__':
run_model() run_model()
......
...@@ -930,7 +930,7 @@ class BertForSequenceClassification(BertPreTrainedModel): ...@@ -930,7 +930,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
Inputs: Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts with the word token indices in the vocabulary. Items in the batch should begin with the special "CLS" token. (see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`) `extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
......
...@@ -605,14 +605,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -605,14 +605,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
return return
# Update config # Update config
self.config.n_special = num_special_tokens self.config.n_special = num_special_tokens
# # Build new embeddings and initialize # Build new embeddings and initialize all new embeddings (in particular the special tokens)
old_embed = self.tokens_embed old_embed = self.tokens_embed
self.tokens_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd) self.tokens_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
# Initialize all new embeddings (in particular the special tokens)
self.init_weights(self.tokens_embed) self.init_weights(self.tokens_embed)
# Copy word and positional embeddings from the previous weights # Copy word embeddings from the previous weights
self.tokens_embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :] self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
self.tokens_embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
def forward(self, input_ids, position_ids=None, token_type_ids=None): def forward(self, input_ids, position_ids=None, token_type_ids=None):
if position_ids is None: if position_ids is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment