"""utils_lm.py — WikiText dataset utilities for language-model training."""
from torch.utils.data import Dataset, DataLoader
import os
import random
import torch
import torch.nn.functional as F


class WikiTextDataset(Dataset):
    """Fixed-length token chunks of a raw WikiText split for LM training.

    Reads ``wiki.{file}.raw`` from ``directory``, tokenizes the entire text
    with the supplied tokenizer, and splits the token-id stream into
    consecutive, non-overlapping chunks of ``max_context_length`` ids.
    A trailing remainder shorter than (or exactly equal to) one full chunk
    is discarded, matching the usual LM-pretraining convention.
    """

    def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=1024):
        """Build the list of fixed-length examples.

        Args:
            tokenizer: object exposing ``tokenize(str) -> list[str]`` and
                ``convert_tokens_to_ids(list[str]) -> list[int]``
                (e.g. a HuggingFace tokenizer).
            file: split name — the file read is ``wiki.{file}.raw``.
            directory: directory containing the raw WikiText files.
            max_context_length: number of token ids per example.
        """
        self.max_context_length = max_context_length
        self.examples = []

        with open(os.path.join(directory, f"wiki.{file}.raw"), encoding="utf-8") as f:
            text = f.read()
        token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

        # Stride over the id stream instead of repeatedly re-slicing the
        # (potentially huge) list, which was accidentally O(n^2).  The loop
        # bound `len - max_context_length` reproduces the original behavior
        # exactly: only chunks with MORE ids remaining after them are kept,
        # so the final partial (or exactly-full) remainder is dropped.
        for start in range(0, len(token_ids) - max_context_length, max_context_length):
            self.examples.append(token_ids[start:start + max_context_length])

    def __len__(self):
        """Number of fixed-length examples."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return example ``item`` as a 1-D LongTensor of token ids."""
        return torch.tensor(self.examples[item])