from typing import Callable

from torch.utils.data import Dataset
from tqdm import tqdm

from .utils import is_rank_0

# Dahoas/rm-static
class RmStaticDataset(Dataset):
    """
    Dataset for reward model training on Dahoas/rm-static-style data.

    Each source sample is a dict with 'prompt', 'chosen' and 'rejected'
    text fields; the prompt is concatenated with each response (plus an
    end-of-sentence token) and tokenized up front.

    Args:
        dataset: dataset for reward model
        tokenizer: tokenizer for reward model
        max_length: max length of input
        special_token: special token at the end of sentence
            (defaults to the tokenizer's EOS token)
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.chosen = []
        self.reject = []
        # Fall back to the tokenizer's EOS token when no explicit
        # end-of-sentence token is supplied.
        self.end_token = tokenizer.eos_token if special_token is None else special_token
        # Progress bar only on rank 0 to avoid duplicated output in
        # distributed runs.
        for data in tqdm(dataset, disable=not is_rank_0()):
            prompt = data['prompt']
            self.chosen.append(self._tokenize(tokenizer, prompt + data['chosen'] + self.end_token, max_length))
            self.reject.append(self._tokenize(tokenizer, prompt + data['rejected'] + self.end_token, max_length))

    @staticmethod
    def _tokenize(tokenizer: Callable, text: str, max_length: int) -> dict:
        """Tokenize *text* padded/truncated to *max_length* and return a
        dict with its 'input_ids' and 'attention_mask' tensors."""
        token = tokenizer(text,
                          max_length=max_length,
                          padding="max_length",
                          truncation=True,
                          return_tensors="pt")
        return {
            "input_ids": token['input_ids'],
            "attention_mask": token['attention_mask']
        }

    def __len__(self):
        # One chosen/rejected pair per source sample.
        return len(self.chosen)

    def __getitem__(self, idx):
        # Returns (chosen_ids, chosen_mask, reject_ids, reject_mask).
        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
            "input_ids"], self.reject[idx]["attention_mask"]

# Anthropic/hh-rlhf
class HhRlhfDataset(Dataset):
    """
    Dataset for reward model training on Anthropic/hh-rlhf-style data.

    Args:
        dataset: dataset for reward model; each sample provides 'chosen'
            and 'rejected' text fields
        tokenizer: tokenizer for reward model
        max_length: max length of input
        special_token: special token at the end of sentence
            (defaults to the tokenizer's EOS token)
    """
    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.chosen = []
        self.reject = []
        # Use the tokenizer's EOS token unless an explicit one is given.
        self.end_token = special_token if special_token is not None else tokenizer.eos_token
        # Show the progress bar only on rank 0.
        for sample in tqdm(dataset, disable=not is_rank_0()):
            # Tokenize the chosen and rejected responses identically,
            # appending the end token to each.
            for text, store in ((sample['chosen'], self.chosen), (sample['rejected'], self.reject)):
                encoded = tokenizer(text + self.end_token,
                                    max_length=max_length,
                                    padding="max_length",
                                    truncation=True,
                                    return_tensors="pt")
                store.append({
                    "input_ids": encoded['input_ids'],
                    "attention_mask": encoded['attention_mask']
                })

    def __len__(self):
        return len(self.chosen)

    def __getitem__(self, idx):
        # (chosen_ids, chosen_mask, reject_ids, reject_mask)
        chosen_pair = self.chosen[idx]
        reject_pair = self.reject[idx]
        return (chosen_pair["input_ids"], chosen_pair["attention_mask"],
                reject_pair["input_ids"], reject_pair["attention_mask"])