# data.py 1.25 KB
# Newer Older
1
2
import torch
from datasets import load_dataset
3
from torch.utils.data import Dataset
4
5
6
7
8
9
10
11
12
13


class NetflixDataset(Dataset):
    """Tokenized Netflix show descriptions, padded to one common length.

    Every description is wrapped as ``"</s>" + text + "</s>"``, tokenized and
    padded to ``self.max_length``, so each item is an
    ``(input_ids, attention_mask)`` pair of equally sized tensors.

    Args:
        tokenizer: Object used both as ``tokenizer.encode(text)`` (to count
            tokens) and as ``tokenizer(text, truncation=..., max_length=...,
            padding=...)``; the call must return a mapping with
            ``"input_ids"`` and ``"attention_mask"`` entries.
        descriptions: Optional iterable of description strings. When omitted,
            the ``"description"`` column of the ``train`` split of
            ``hugginglearners/netflix-shows`` is fetched with ``load_dataset``.

    Attributes:
        txt_list: The raw (unwrapped) descriptions.
        max_length: Token count of the longest *wrapped* description.
        input_ids / attn_masks: One tensor per description.
        labels: Always an empty list; this class never fills it.
    """

    def __init__(self, tokenizer, descriptions=None):
        super().__init__()

        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []
        # Never populated here; kept so existing attribute access keeps working.
        self.labels = []

        if descriptions is None:
            descriptions = load_dataset("hugginglearners/netflix-shows", split="train")["description"]
        else:
            descriptions = list(descriptions)
        self.txt_list = descriptions

        # Build the exact strings that get tokenized, once, so the length
        # measurement and the final encoding cannot drift apart.
        wrapped = ["</s>" + text + "</s>" for text in self.txt_list]

        # Measure the wrapped text: measuring the bare description would make
        # truncation chop the closing "</s>" off the longest samples.
        # default=0 keeps an empty corpus from raising in max().
        self.max_length = max((len(self.tokenizer.encode(text)) for text in wrapped), default=0)

        for text in wrapped:
            encodings = self.tokenizer(
                text, truncation=True, max_length=self.max_length, padding="max_length"
            )
            self.input_ids.append(torch.tensor(encodings["input_ids"]))
            self.attn_masks.append(torch.tensor(encodings["attention_mask"]))

    def __len__(self):
        """Return the number of encoded descriptions."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, attention_mask)`` tensors at ``idx``."""
        return self.input_ids[idx], self.attn_masks[idx]
31

32
33

def netflix_collator(data):
    """Stack ``(input_ids, attention_mask)`` pairs into one training batch.

    Args:
        data: Sequence of ``(input_ids, attention_mask)`` tensor pairs, all of
            the same shape (the item format of ``NetflixDataset``).

    Returns:
        dict with three stacked tensors:
            ``"input_ids"``: the first element of every pair.
            ``"attention_mask"``: the second element of every pair.
            ``"labels"``: a copy of ``input_ids`` in which every position with
            ``attention_mask == 0`` is replaced by ``-100``.
    """
    input_ids = torch.stack([sample[0] for sample in data])
    attention_mask = torch.stack([sample[1] for sample in data])

    # -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so padded
    # positions no longer contribute to the loss. masked_fill is out-of-place,
    # which leaves input_ids untouched.
    labels = input_ids.masked_fill(attention_mask == 0, -100)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }