# reward_dataset.py
from typing import Callable

from torch.utils.data import Dataset
from tqdm import tqdm

from .utils import is_rank_0


# Dahoas/rm-static
class RmStaticDataset(Dataset):
    """
    Pairwise preference dataset for reward-model training (Dahoas/rm-static layout).

    Each record must provide "prompt", "chosen" and "rejected" entries. The prompt
    is prepended to both responses, the end token is appended, and both sides are
    tokenized eagerly into fixed-length tensors.

    Args:
        dataset: iterable of records with "prompt", "chosen" and "rejected" keys.
            It is traversed exactly once, so one-shot iterables are supported and
            the chosen/rejected sides are guaranteed to stay index-aligned.
        tokenizer: tokenizer for reward model; called on a list of strings and
            must expose ``eos_token`` when ``special_token`` is None.
        max_length: max length of input; every sequence is padded/truncated to it.
        special_token: token appended to the end of every sequence; defaults to
            ``tokenizer.eos_token``.
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # Build both sides in ONE pass: records are read/decoded only once, and a
        # one-shot iterable cannot leave the rejected side empty or misaligned.
        chosen = []
        reject = []
        for data in tqdm(dataset, disable=not is_rank_0()):
            prompt = data["prompt"]
            chosen.append(prompt + data["chosen"] + self.end_token)
            reject.append(prompt + data["rejected"] + self.end_token)

        # NOTE(review): with right-side truncation (typical tokenizer default) the
        # appended end token is dropped for texts longer than max_length — confirm
        # whether the reward model relies on seeing it.
        self.chosen = self._tokenize(tokenizer, chosen, max_length)
        self.reject = self._tokenize(tokenizer, reject, max_length)

    @staticmethod
    def _tokenize(tokenizer: Callable, texts, max_length: int) -> dict:
        """Tokenize ``texts`` to fixed-length tensors and keep only ids and mask."""
        tokens = tokenizer(texts,
                           max_length=max_length,
                           padding="max_length",
                           truncation=True,
                           return_tensors="pt")
        return {
            "input_ids": tokens["input_ids"],
            "attention_mask": tokens["attention_mask"],
        }

    def __len__(self):
        """Return the number of preference pairs (rows of the chosen tensors)."""
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        """Return ``(chosen_ids, chosen_mask, reject_ids, reject_mask)`` for ``idx``."""
        return (self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx],
                self.reject["input_ids"][idx], self.reject["attention_mask"][idx])


# Anthropic/hh-rlhf
class HhRlhfDataset(Dataset):
    """
    Pairwise preference dataset for reward-model training (Anthropic/hh-rlhf layout).

    Each record must provide "chosen" and "rejected" entries. No separate prompt
    is prepended, so each entry is presumably the full conversation text (confirm
    against the loaded dataset). The end token is appended and both sides are
    tokenized eagerly into fixed-length tensors.

    Args:
        dataset: iterable of records with "chosen" and "rejected" keys. It is
            traversed exactly once, so one-shot iterables are supported and the
            chosen/rejected sides are guaranteed to stay index-aligned.
        tokenizer: tokenizer for reward model; called on a list of strings and
            must expose ``eos_token`` when ``special_token`` is None.
        max_length: max length of input; every sequence is padded/truncated to it.
        special_token: token appended to the end of every sequence; defaults to
            ``tokenizer.eos_token``.
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # Build both sides in ONE pass: records are read/decoded only once, and a
        # one-shot iterable cannot leave the rejected side empty or misaligned.
        chosen = []
        reject = []
        for data in tqdm(dataset, disable=not is_rank_0()):
            chosen.append(data["chosen"] + self.end_token)
            reject.append(data["rejected"] + self.end_token)

        # NOTE(review): with right-side truncation (typical tokenizer default) the
        # appended end token is dropped for texts longer than max_length — confirm
        # whether the reward model relies on seeing it.
        self.chosen = self._tokenize(tokenizer, chosen, max_length)
        self.reject = self._tokenize(tokenizer, reject, max_length)

    @staticmethod
    def _tokenize(tokenizer: Callable, texts, max_length: int) -> dict:
        """Tokenize ``texts`` to fixed-length tensors and keep only ids and mask."""
        tokens = tokenizer(texts,
                           max_length=max_length,
                           padding="max_length",
                           truncation=True,
                           return_tensors="pt")
        return {
            "input_ids": tokens["input_ids"],
            "attention_mask": tokens["attention_mask"],
        }

    def __len__(self):
        """Return the number of preference pairs (rows of the chosen tensors)."""
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        """Return ``(chosen_ids, chosen_mask, reject_ids, reject_mask)`` for ``idx``."""
        return (self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx],
                self.reject["input_ids"][idx], self.reject["attention_mask"][idx])