Commit 1a8e87be authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Line-by-line text dataset (including padding)

parent b94cf7fa
......@@ -32,6 +32,7 @@ from typing import Dict, List, Tuple
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
......@@ -83,7 +84,7 @@ MODEL_CLASSES = {
class TextDataset(Dataset):
def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path="train", block_size=512):
def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
assert os.path.isfile(file_path)
directory, filename = os.path.split(file_path)
cached_features_file = os.path.join(
......@@ -120,13 +121,32 @@ class TextDataset(Dataset):
return torch.tensor(self.examples[item])
class LineByLineTextDataset(Dataset):
def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
assert os.path.isfile(file_path)
# Here, we do not cache the features, operating under the assumption
# that we will soon use fast multithreaded tokenizers from the
# `tokenizers` repo everywhere =)
logger.info("Creating features from dataset file at %s", file_path)
with open(file_path, encoding="utf-8") as f:
lines = [line for line in f.read().splitlines() if len(line) > 0]
self.examples = tokenizer.batch_encode_plus(lines, max_length=block_size)["input_ids"]
def __len__(self):
return len(self.examples)
def __getitem__(self, i):
return torch.tensor(self.examples[i])
def load_and_cache_examples(args, tokenizer, evaluate=False):
return TextDataset(
tokenizer,
args,
file_path=args.eval_data_file if evaluate else args.train_data_file,
block_size=args.block_size,
)
file_path = args.eval_data_file if evaluate else args.train_data_file
if args.line_by_line:
return LineByLineTextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size)
else:
return TextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size)
def set_seed(args):
......@@ -182,6 +202,8 @@ def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> T
tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
padding_mask = labels.eq(tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens
......@@ -204,8 +226,14 @@ def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedToke
tb_writer = SummaryWriter()
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
def collate(examples: List[torch.Tensor]):
return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
train_dataloader = DataLoader(
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
)
if args.max_steps > 0:
t_total = args.max_steps
......@@ -391,8 +419,14 @@ def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefi
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly
def collate(examples: List[torch.Tensor]):
return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
eval_dataloader = DataLoader(
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate
)
# multi-gpu evaluate
if args.n_gpu > 1:
......@@ -456,11 +490,14 @@ def main():
type=str,
help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
)
parser.add_argument(
"--line_by_line",
action="store_true",
help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.",
)
parser.add_argument(
"--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
)
parser.add_argument(
"--model_name_or_path",
default=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment