Commit 5652f54a authored by Lysandre

Simplified data generator + better perplexity calculator

GPT-2 now obtains ~20 perplexity on WikiText-2
parent 71553480
@@ -85,7 +85,7 @@ def train(args, train_dataset, model, tokenizer):
     args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
     train_sampler = SequentialSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
-    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=WikiTextDataset.collate)
+    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
     if args.max_steps > 0:
         t_total = args.max_steps
@@ -209,7 +209,7 @@ def evaluate(args, model, tokenizer, prefix=""):
     args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
     # Note that DistributedSampler samples randomly
     eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
-    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=WikiTextDataset.collate)
+    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
     # Eval!
     logger.info("***** Running evaluation {} *****".format(prefix))
@@ -217,12 +217,13 @@ def evaluate(args, model, tokenizer, prefix=""):
     logger.info(" Batch size = %d", args.eval_batch_size)
     eval_loss = 0.0
     nb_eval_steps = 0
+    model.eval()
     for batch in tqdm(eval_dataloader, desc="Evaluating"):
-        model.eval()
         batch = batch.to(args.device)
         with torch.no_grad():
-            outputs = model(batch)
+            outputs = model(batch, masked_lm_labels=batch) if args.mlm else model(batch, labels=batch)
             lm_loss = outputs[0]
             eval_loss += lm_loss.mean().item()
         nb_eval_steps += 1
...
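The hunk above accumulates the per-batch language-modeling loss; the collapsed lines that follow presumably turn it into the reported perplexity. A minimal sketch of that step, assuming the usual exp-of-mean-loss definition (these lines are an assumption and are not shown in this diff):

# Assumed continuation of evaluate() (not visible in the diff):
# perplexity is the exponential of the mean cross-entropy over all eval batches.
eval_loss = eval_loss / nb_eval_steps
perplexity = torch.exp(torch.tensor(eval_loss))
result = {"perplexity": perplexity}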
@@ -6,34 +6,21 @@ import torch.nn.functional as F
 class WikiTextDataset(Dataset):
-    def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=512):
+    def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=1024):
         self.max_context_length = max_context_length
         self.examples = []
         with open(os.path.join(directory, f"wiki.{file}.raw"), encoding="utf-8") as f:
             text = f.read()
-        spans = list(filter(lambda item: len(item) > 120, text.split("\n")))
-        for span in spans:
-            span = tokenizer.encode(span)
-            while len(span) > 0:
-                self.examples.append(span[:max_context_length])
-                span = span[max_context_length:]
-        # Randomly shuffle the examples array
-        random.shuffle(self.examples)
-        # Sort the array by example length.
-        self.examples.sort(key=len)
+        tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
+        while len(tokenized_text) > max_context_length:
+            self.examples.append(tokenized_text[:max_context_length])
+            tokenized_text = tokenized_text[max_context_length:]
     def __len__(self):
         return len(self.examples)
     def __getitem__(self, item):
         return torch.tensor(self.examples[item])
-    @staticmethod
-    def collate(values):
-        stack = torch.stack([F.pad(value, (len(values[-1]) - value.size(0), 0), "constant", 0) for value in values])
-        return stack
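Because every example produced by the simplified __init__ is exactly max_context_length token ids long (the trailing partial chunk is dropped), PyTorch's default collate function can stack a batch into a single tensor, which is why the collate_fn=WikiTextDataset.collate argument and the padding-based collate method could be removed from both DataLoaders above. A minimal usage sketch under that assumption; the GPT2Tokenizer import and the 'valid' split name are illustrative and not part of this commit:

from torch.utils.data import DataLoader, SequentialSampler
from pytorch_transformers import GPT2Tokenizer  # illustrative import; adjust to the installed package

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
dataset = WikiTextDataset(tokenizer, file="valid", directory="wikitext", max_context_length=1024)

# Fixed-length examples stack into a (batch_size, 1024) LongTensor with the default collate,
# so no padding and no custom collate_fn are needed.
loader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=4)
for batch in loader:
    print(batch.shape)  # torch.Size([4, 1024]) for full batches
    break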