# reward_dataset.py
from typing import Callable

from torch.utils.data import Dataset
from tqdm import tqdm

from .utils import is_rank_0


# Dahoas/rm-static
class RmStaticDataset(Dataset):
    """
    Dataset of (chosen, rejected) text pairs for reward model training.

    Each record of ``dataset`` is read through its ``"prompt"``, ``"chosen"``
    and ``"rejected"`` keys. Two strings are built per record,
    ``prompt + chosen + end_token`` and ``prompt + rejected + end_token``,
    and all strings are tokenized once in ``__init__`` with
    ``padding="max_length"``, ``truncation=True`` and ``return_tensors="pt"``.

    Args:
        dataset: iterable of records indexable by "prompt", "chosen" and "rejected"
        tokenizer: tokenizer for reward model; called on a list of strings and
            expected to return a mapping with "input_ids" and "attention_mask"
        max_length: max length of input, forwarded to the tokenizer
        special_token: special token appended at the end of every sentence;
            when None, ``tokenizer.eos_token`` is used instead
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # Walk the dataset once and build both sides of every pair together,
        # instead of iterating it a second time for the rejected answers.
        chosen, reject = [], []
        # The progress bar is disabled whenever is_rank_0() is false.
        for data in tqdm(dataset, disable=not is_rank_0()):
            chosen.append(data["prompt"] + data["chosen"] + self.end_token)
            reject.append(data["prompt"] + data["rejected"] + self.end_token)

        self.chosen = self._tokenize(tokenizer, chosen, max_length)
        self.reject = self._tokenize(tokenizer, reject, max_length)

    @staticmethod
    def _tokenize(tokenizer: Callable, texts, max_length: int) -> dict:
        """Tokenize ``texts`` to fixed-length tensors and keep only ids and mask."""
        tokens = tokenizer(
            texts, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}

    def __len__(self):
        """Return the number of pairs, i.e. the first dimension of the chosen input ids."""
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        """Return (chosen ids, chosen mask, rejected ids, rejected mask) at ``idx``."""
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )


# Anthropic/hh-rlhf
class HhRlhfDataset(Dataset):
    """
    Dataset of (chosen, rejected) text pairs for reward model training.

    Each record of ``dataset`` is read through its ``"chosen"`` and
    ``"rejected"`` keys; no separate prompt field is read. Two strings are
    built per record, ``chosen + end_token`` and ``rejected + end_token``,
    and all strings are tokenized once in ``__init__`` with
    ``padding="max_length"``, ``truncation=True`` and ``return_tensors="pt"``.

    Args:
        dataset: iterable of records indexable by "chosen" and "rejected"
        tokenizer: tokenizer for reward model; called on a list of strings and
            expected to return a mapping with "input_ids" and "attention_mask"
        max_length: max length of input, forwarded to the tokenizer
        special_token: special token appended at the end of every sentence;
            when None, ``tokenizer.eos_token`` is used instead
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # Walk the dataset once and build both sides of every pair together,
        # instead of iterating it a second time for the rejected answers.
        chosen, reject = [], []
        # The progress bar is disabled whenever is_rank_0() is false.
        for data in tqdm(dataset, disable=not is_rank_0()):
            chosen.append(data["chosen"] + self.end_token)
            reject.append(data["rejected"] + self.end_token)

        self.chosen = self._tokenize(tokenizer, chosen, max_length)
        self.reject = self._tokenize(tokenizer, reject, max_length)

    @staticmethod
    def _tokenize(tokenizer: Callable, texts, max_length: int) -> dict:
        """Tokenize ``texts`` to fixed-length tensors and keep only ids and mask."""
        tokens = tokenizer(
            texts, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}

    def __len__(self):
        """Return the number of pairs, i.e. the first dimension of the chosen input ids."""
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        """Return (chosen ids, chosen mask, rejected ids, rejected mask) at ``idx``."""
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )