utils_lm.py
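"""Dataset utilities for WikiText language-model fine-tuning.

WikiTextDataset tokenizes a raw WikiText split, slices the token stream into
fixed-length contexts, and caches the result to disk so that, under
distributed training, only the first process does the preprocessing.
"""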
from torch.utils.data import Dataset, DataLoader
import os
import random
import torch
import torch.nn.functional as F
import logging
import pickle

logger = logging.getLogger(__name__)


class WikiTextDataset(Dataset):
	def __init__(self, args, tokenizer, file='train', directory='wikitext', max_context_length=512, cache=None):
		if args.local_rank not in [-1, 0]:
			# Make sure only the first process in distributed training processes
			# the dataset; the other processes wait here and load from the cache.
			torch.distributed.barrier()

		self.max_context_length = max_context_length

		cached_features_file = os.path.join(args.data_dir, f'cached_lm_{file}_{args.max_seq_length}')

		if os.path.exists(cached_features_file):
			logger.info("Loading features from cached file %s", cached_features_file)
			with open(cached_features_file, 'rb') as handle:
				self.examples = pickle.load(handle)
		else:
			logger.info("Creating features from dataset file at %s", args.data_dir)

			self.examples = []

			with open(os.path.join(directory, f"wiki.{file}.raw"), encoding="utf-8") as f:
				text = f.read()
				tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

				# Slice the token stream into fixed-length, non-overlapping contexts;
				# a trailing remainder shorter than max_context_length is dropped.
				while len(tokenized_text) > max_context_length:
					self.examples.append(tokenized_text[:max_context_length])
					tokenized_text = tokenized_text[max_context_length:]

			if args.local_rank in [-1, 0]:
				logger.info("Saving features into cached file %s", cached_features_file)
				with open(cached_features_file, 'wb') as handle:
					pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

		if args.local_rank == 0:
			# Rank 0 has finished building and caching the features; release the
			# waiting processes so they can load the cache.
			torch.distributed.barrier()

	def __len__(self):
		return len(self.examples)

	def __getitem__(self, item):
		return torch.tensor(self.examples[item])
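

if __name__ == "__main__":
	# Minimal usage sketch (not part of the training pipeline). It assumes a
	# Hugging Face GPT-2 tokenizer/model, the WikiText raw files under
	# ./wikitext, and an `args` namespace carrying the fields this class reads
	# (local_rank, data_dir, max_seq_length); adapt these names to your script.
	from argparse import Namespace
	from transformers import GPT2Tokenizer, GPT2LMHeadModel

	args = Namespace(local_rank=-1, data_dir='wikitext', max_seq_length=512)
	tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
	model = GPT2LMHeadModel.from_pretrained('gpt2')

	dataset = WikiTextDataset(args, tokenizer, file='valid', max_context_length=args.max_seq_length)
	loader = DataLoader(dataset, batch_size=2, shuffle=True)

	batch = next(iter(loader))  # shape: (batch_size, max_context_length)
	with torch.no_grad():
		logits = model(batch)[0]  # shape: (batch_size, max_context_length, vocab_size)
		# Causal LM objective: predict token t+1 from tokens up to t.
		loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)), batch[:, 1:].reshape(-1))
	print(f"smoke-test perplexity: {loss.exp().item():.2f}")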