"LICENSE" did not exist on "daf734774d86163475c9f708eaacb44f04908143"
dataset.py 2.89 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import time
import numpy as np
import torch
from datasets import load_dataset, load_from_disk
from tokenizer_GPTJ import get_transformer_autotokenizer
from torch.nn.functional import pad
from torch.utils.data import DataLoader
from typing import Optional, Dict, Sequence
import io
import utils
import copy

# Prompt templates, keyed by whether the record supplies an ``input`` field.
# The placeholders ({instruction}, {input}) are filled in with
# ``str.format_map`` using one dataset record at a time.
PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, "
        "paired with an input that provides further context. "
        "Write a response that appropriately completes the request."
        "\n\n### Instruction:\n{instruction}"
        "\n\n### Input:\n{input}"
        "\n\n### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
        "\n\n### Instruction:\n{instruction}"
        "\n\n### Response:"
    ),
)


class Dataset:
    """Pre-tokenized prompt set built from a record file.

    The constructor loads the records at ``dataset_path`` via ``utils.jload``,
    renders every record into the ``prompt_input`` template, and tokenizes all
    prompts immediately, so the encoded inputs are available by index as soon
    as the object exists.

    NOTE(review): the "QSL" wording in the log messages and the
    ``LoadSamplesToRam`` / ``UnloadSamplesFromRam`` method names suggest this
    object backs an MLPerf LoadGen Query Sample Library -- confirm against the
    caller.
    """

    def __init__(
        self,
        dataset_path,
        batch_size=1,
        pad_val=1,
        pad_max=196,
        total_count_override=None,
        perf_count_override=None,
    ):
        """Load, format and tokenize every record found at ``dataset_path``.

        Args:
            dataset_path: Location handed to ``utils.jload``; the result is
                iterated as a sequence of mappings (presumably a JSON list of
                records -- verify against ``utils.jload``). Each record must
                provide the ``instruction``, ``input`` and ``output`` keys.
            batch_size: Stored on the instance; not otherwise used here.
            pad_val: Stored on the instance; not otherwise used here.
            pad_max: Stored on the instance; not otherwise used here.
            total_count_override: If truthy, reported as ``self.count``
                instead of the number of loaded prompts.
            perf_count_override: If truthy, reported as ``self.perf_count``
                instead of ``self.count``.

        Raises:
            KeyError: If a record lacks a key required by the prompt template
                or the ``output`` key.
        """
        print("Constructing QSL")

        self.dataset = "cnn_dailymail"
        self.model_name = "EleutherAI/gpt-j-6B"
        self.dataset_path = dataset_path
        self.batch_size = batch_size
        self.pad_val = pad_val
        self.pad_max = pad_max

        self.tokenizer = get_transformer_autotokenizer(self.model_name)
        # The tokenizer is given the EOS token as its padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.list_data_dict = utils.jload(self.dataset_path)

        # Every record is rendered with the "with input" template; the
        # "prompt_no_input" variant is intentionally not selected here.
        prompt_input = PROMPT_DICT["prompt_input"]
        self.sources = [
            prompt_input.format_map(example) for example in self.list_data_dict
        ]
        self.targets = [
            f"{example['output']}" for example in self.list_data_dict
        ]

        self.source_encoded_input_ids, self.source_encoded_attn_masks = (
            self.encode_samples()
        )

        # ``or`` means an override of None *or* 0 falls back to the default.
        self.count = total_count_override or len(self.sources)
        self.perf_count = perf_count_override or self.count

    def encode_samples(self, max_length=1919):
        """Tokenize every prompt in ``self.sources`` one at a time.

        Args:
            max_length: Truncation limit forwarded to the tokenizer for each
                prompt. Defaults to 1919, the value previously hard-coded.

        Returns:
            A ``(input_ids, attention_masks)`` tuple of two lists, each with
            one entry per prompt in the same order as ``self.sources``. The
            entries are whatever the tokenizer exposes as ``input_ids`` and
            ``attention_mask`` (presumably ``torch`` tensors, given
            ``return_tensors="pt"`` -- TODO confirm).
        """
        print("Encoding Samples")

        source_encoded_input_ids = []
        source_encoded_attn_masks = []

        for source in self.sources:
            encoded = self.tokenizer(
                source,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            )
            source_encoded_input_ids.append(encoded.input_ids)
            source_encoded_attn_masks.append(encoded.attention_mask)

        return source_encoded_input_ids, source_encoded_attn_masks

    def LoadSamplesToRam(self, sample_list):
        """Do nothing; all samples are already encoded by the constructor."""
        pass

    def UnloadSamplesFromRam(self, sample_list):
        """Do nothing; encoded samples are kept for the object's lifetime."""
        pass

    def __del__(self):
        """Log that the object is being destroyed."""
        print("Finished destroying QSL.")