Commit 569897ce authored by Lysandre, committed by Julien Chaumond
Browse files

Fix a few issues regarding the language modeling script

parent 21da8950
...@@ -130,9 +130,9 @@ class LineByLineTextDataset(Dataset): ...@@ -130,9 +130,9 @@ class LineByLineTextDataset(Dataset):
logger.info("Creating features from dataset file at %s", file_path) logger.info("Creating features from dataset file at %s", file_path)
with open(file_path, encoding="utf-8") as f: with open(file_path, encoding="utf-8") as f:
lines = [line for line in f.read().splitlines() if len(line) > 0] lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
self.examples = tokenizer.batch_encode_plus(lines, max_length=block_size)["input_ids"] self.examples = tokenizer.batch_encode_plus(lines, add_special_tokens=True, max_length=block_size)["input_ids"]
def __len__(self): def __len__(self):
return len(self.examples) return len(self.examples)
...@@ -704,10 +704,10 @@ def main(): ...@@ -704,10 +704,10 @@ def main():
) )
if args.block_size <= 0: if args.block_size <= 0:
args.block_size = tokenizer.max_len_single_sentence args.block_size = tokenizer.max_len
# Our input block size will be the max possible for the model # Our input block size will be the max possible for the model
else: else:
args.block_size = min(args.block_size, tokenizer.max_len_single_sentence) args.block_size = min(args.block_size, tokenizer.max_len)
if args.model_name_or_path: if args.model_name_or_path:
model = model_class.from_pretrained( model = model_class.from_pretrained(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment