sft_dataset.py
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import copy
from typing import Dict, Optional, Sequence, Tuple

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from colossalai.logging import get_dist_logger

from .utils import is_rank_0, jload

logger = get_dist_logger()

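# Label value that PyTorch's cross-entropy loss ignores by default; positions
# marked with it do not contribute to the training objective.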
IGNORE_INDEX = -100
PROMPT_DICT = {
    "prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. "
                     "Write a response that appropriately completes the request.\n\n"
                     "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
    "prompt_no_input": ("Below is an instruction that describes a task. "
                        "Write a response that appropriately completes the request.\n\n"
                        "### Instruction:\n{instruction}\n\n### Response:"),
}
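# For illustration: a record like {"instruction": "...", "input": "..."} is rendered
# with "prompt_input", while a record without an "input" key falls back to
# "prompt_no_input"; both templates end with "### Response:" so the model's
# completion can be appended directly.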


def _preprocess(sources: Sequence[str],
                targets: Sequence[str],
                tokenizer: PreTrainedTokenizer,
                max_length: int,
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Preprocess the data by tokenizing."""
    sequences = [s + t for s, t in zip(sources, targets)]
    sequences_token = tokenizer(sequences,
                                max_length=max_length,
                                padding="max_length",
                                truncation=True,
                                return_tensors="pt")
    sources_token = tokenizer(sources,
                              max_length=max_length,
                              padding="max_length",
                              truncation=True,
                              return_tensors="pt")

    labels = copy.deepcopy(sequences_token["input_ids"])
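    # Mask the prompt portion of every sequence so that the loss is computed only
    # on the completion tokens; where the prompt sits depends on the padding side.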
    for i in range(labels.shape[0]):
        source_len = sources_token["attention_mask"][i].sum().item()
        pad_len = max_length - sequences_token["attention_mask"][i].sum().item()
        if tokenizer.padding_side == "right":
            # |prompt|completion|eos|pad|
            labels[i][:source_len] = IGNORE_INDEX
        elif tokenizer.padding_side == "left":
            # |pad|prompt|completion|eos|
            labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX
        else:
            raise RuntimeError(f"Unsupported padding side: {tokenizer.padding_side}")

    return sequences_token["input_ids"], labels, sequences_token["attention_mask"]


def _preprocess_chatglm(sources: Sequence[str],
                        targets: Sequence[str],
                        tokenizer: PreTrainedTokenizer,
                        max_length: int,
                        ) -> Tuple[torch.Tensor, torch.Tensor, None]:
    """
    Preprocess the data by tokenizing.

    The attention mask is returned as None: the ChatGLM model builds its own
    attention mask from the input ids.
    """
    labels = []
    input_ids = []
    for source, target in zip(sources, targets):
        source_id = tokenizer.encode(text=source, add_special_tokens=False)
        target_id = tokenizer.encode(text=target, add_special_tokens=False)
        input_id = tokenizer.build_inputs_with_special_tokens(source_id, target_id)
        # truncate
        sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
        truncate_length = max(0, len(input_id) - max_length)
        input_id = input_id[truncate_length: ]
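        # If left-truncation cut off the [gMASK]/<sop> special tokens that mark the
        # prompt/response boundary, restore them so the bos lookup below still works.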
        if truncate_length == len(source_id) + 1:
            input_id = sp_token_list + input_id[1: ]
        elif truncate_length > len(source_id) + 1:
            input_id = sp_token_list + input_id[2: ]
        
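        # Mask the prompt and the [gMASK] token in the labels; supervision starts at
        # the <sop> token, so the loss covers only the response and its EOS.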
        context_length = input_id.index(tokenizer.bos_token_id)
        mask_position = context_length - 1
        label = [IGNORE_INDEX] * context_length + input_id[mask_position+1:]
        
        pad_len = max_length - len(input_id)
        input_id = input_id + [tokenizer.pad_token_id] * pad_len
        input_ids.append(input_id)
        labels.append(label + [IGNORE_INDEX] * pad_len)
    return torch.tensor(input_ids), torch.tensor(labels), None


class SFTDataset(Dataset):
    """
    Dataset for supervised fine-tuning (SFT).

    Args:
        dataset: a sequence of samples, each a dict with "prompt" and "completion" keys
        tokenizer: tokenizer used to encode the samples
        max_length: maximum length of each tokenized sequence
    """

    def __init__(self,
                 dataset: Sequence[Dict],
                 tokenizer: PreTrainedTokenizer,
                 max_length: int = 512
                 ) -> None:
        super().__init__()
        self.input_ids = []

        sources = [data["prompt"] for data in dataset]
        targets = [
            data["completion"] + tokenizer.eos_token
            for data in tqdm(dataset, disable=not is_rank_0())
        ]
        if isinstance(tokenizer, ChatGLMTokenizer):
            self.input_ids, self.labels, self.attention_mask = \
                _preprocess_chatglm(sources, targets, tokenizer, max_length)
        else:
            self.input_ids, self.labels, self.attention_mask = \
                _preprocess(sources, targets, tokenizer, max_length)

    def __len__(self):
        length = self.input_ids.shape[0]
        return length

    def __getitem__(self, idx):
        if self.attention_mask is not None:
            return dict(input_ids=self.input_ids[idx],
                        labels=self.labels[idx],
                        attention_mask=self.attention_mask[idx])
        else:
            return dict(input_ids=self.input_ids[idx],
                        labels=self.labels[idx])


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self,
                 data_path: str,
                 tokenizer: PreTrainedTokenizer,
                 max_datasets_size: Optional[int] = None,
                 max_length: int = 512):
        super().__init__()
        logger.info("Loading data...")
        list_data_dict = jload(data_path)
        logger.info(f"Loaded {len(list_data_dict)} examples.")

        if max_datasets_size is not None:
            logger.info(f"Limiting dataset to {max_datasets_size} examples.")
            list_data_dict = list_data_dict[:max_datasets_size]

        logger.info("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [
            example['output'] + tokenizer.eos_token
            for example in list_data_dict
        ]

        logger.info("Tokenizing inputs... This may take some time...")
        if isinstance(tokenizer, ChatGLMTokenizer):
            self.input_ids, self.labels, self.attention_mask = \
                _preprocess_chatglm(sources, targets, tokenizer, max_length)
        else:
            self.input_ids, self.labels, self.attention_mask = \
                _preprocess(sources, targets, tokenizer, max_length)

    def __len__(self):
        length = self.input_ids.shape[0]
        return length

    def __getitem__(self, idx):
        if self.attention_mask is not None:
            return dict(input_ids=self.input_ids[idx],
                        labels=self.labels[idx],
                        attention_mask=self.attention_mask[idx])
        else:
            return dict(input_ids=self.input_ids[idx],
                        labels=self.labels[idx])
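

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline):
# build an SFTDataset from a small in-memory list of {"prompt", "completion"}
# records. The tokenizer name and the sample records below are assumptions
# chosen purely for this example.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token    # GPT-2 ships without a dedicated pad token

    samples = [
        {"prompt": "Translate to French: Hello", "completion": " Bonjour"},
        {"prompt": "2 + 2 =", "completion": " 4"},
    ]
    ds = SFTDataset(samples, tok, max_length=32)
    item = ds[0]
    print(len(ds), item["input_ids"].shape, item["labels"].shape)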