sft_dataset.py 9.17 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import copy
import random
from dataclasses import dataclass, field
18
from typing import Callable, Dict, List, Sequence, Tuple
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
19
20
21
22
23
24
25
26
27

import torch
import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
from tqdm import tqdm

from colossalai.logging import get_dist_logger

28
from .conversation import default_conversation
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
29
30
from .utils import is_rank_0, jload

31
32
33
34
35
36
37
38
# Illustrative layout of a 4-round conversation as rendered for training
# ("xxx" stands for message text; "</s>" closes every turn).
"""
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.

H: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>
"""
# Loss is computed only on the assistant's answer tokens; every other position
# is labelled IGNORE_INDEX.

logger = get_dist_logger()

# Label value for positions that must not contribute to the loss. -100 is the
# default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
IGNORE_INDEX = -100
# Text marker appended after every conversation turn.
DEFAULT_EOS_TOKEN = "</s>"

# Alpaca-style prompt templates. "prompt_input" is used when a record has a
# non-empty "input" field, "prompt_no_input" otherwise; both are filled with
# ``str.format_map`` on the record.
PROMPT_DICT = {
    "prompt_input":
        ("Below is an instruction that describes a task, paired with an input that provides further context. "
         "Write a response that appropriately completes the request.\n\n"
         "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
    "prompt_no_input": ("Below is an instruction that describes a task. "
                        "Write a response that appropriately completes the request.\n\n"
                        "### Instruction:\n{instruction}\n\n### Response:"),
}


class SFTDataset(Dataset):
    """
    Dataset for sft model.

    Every item of ``dataset`` must provide ``'prompt'`` and ``'completion'``
    strings. They are concatenated together with ``tokenizer.eos_token`` and
    tokenized to exactly ``max_length`` tokens (truncated or padded).
    Labels equal the input ids, except that padding positions are set to
    ``IGNORE_INDEX`` so they do not contribute to the loss.

    Args:
        dataset: dataset for supervised model
        tokenizer: tokenizer for supervised model
        max_length: max length of input
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
        super().__init__()
        self.input_ids = []
        self.labels = []

        for data in tqdm(dataset, disable=not is_rank_0()):
            prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
            prompt_token = tokenizer(prompt,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")

            input_ids = prompt_token['input_ids'][0]
            labels = input_ids.clone()
            # Do not supervise padding. Use the attention mask rather than
            # comparing ids with pad_token_id: the pad token frequently shares
            # its id with EOS, and the real EOS must stay supervised.
            attention_mask = prompt_token.get('attention_mask')
            if attention_mask is not None:
                labels[attention_mask[0] == 0] = IGNORE_INDEX

            self.input_ids.append(input_ids)
            self.labels.append(labels)

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
                 max_length: int) -> Dict[str, torch.Tensor]:
    """Tokenize a batch of strings into one padded tensor.

    Args:
        strings: Texts to tokenize together.
        tokenizer: Tokenizer used for the batch; it must be able to pad.
        max_length: Truncation length for every string.

    Returns:
        dict with ``input_ids`` and ``labels`` (the *same* tensor object,
        padded to the longest truncated string in the batch) and
        ``input_ids_lens`` / ``labels_lens`` (number of non-padding tokens in
        each row).
    """
    tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
    input_ids = labels = tokenized_list["input_ids"]

    attention_mask = tokenized_list.get("attention_mask")
    if attention_mask is not None:
        # Count real tokens from the attention mask. Counting ``ids != pad_token_id``
        # undercounts when the pad id is also a genuine token (e.g. pad == eos).
        input_ids_lens = labels_lens = attention_mask.sum(dim=-1)
    else:
        input_ids_lens = labels_lens = input_ids.ne(tokenizer.pad_token_id).sum(dim=-1)

    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    max_length: int,
) -> Dict:
    """Tokenize prompt/response pairs so that only the response is supervised.

    Each ``source + target`` pair is tokenized into one shared tensor, padded to
    the longest example. In the labels, the prompt tokens and every padding
    position are set to ``IGNORE_INDEX``.

    Args:
        sources: Prompt texts.
        targets: Response texts, aligned with ``sources``.
        tokenizer: Tokenizer used for both examples and prompts.
        max_length: Truncation length.

    Returns:
        dict with ``input_ids`` (padded tensor) and ``labels`` (same shape).
    """
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)

    seq_len = input_ids.size(-1)
    left_padded = getattr(tokenizer, "padding_side", "right") == "left"
    for label, source_len, example_len in zip(labels, sources_tokenized["input_ids_lens"],
                                              examples_tokenized["input_ids_lens"]):
        if left_padded:
            # Layout: |pad|prompt|response| -> mask padding and prompt together.
            label[:seq_len - example_len + source_len] = IGNORE_INDEX
        else:
            # Layout: |prompt|response|pad| -> mask the prompt prefix and the
            # padding that follows the shorter examples in the shared tensor.
            label[:source_len] = IGNORE_INDEX
            label[example_len:] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
                            max_length: int) -> Dict:
    """Tokenize multi-turn conversations, supervising only assistant replies.

    Each source is a list of turns (dicts with ``"from"`` and ``"value"``). It is
    rendered by ``_add_speaker_and_signal`` with ``default_conversation.system``
    as header. All conversations are tokenized into one padded tensor. In the
    labels, every position outside an assistant reply (header, human turns,
    speaker tags, padding) is set to ``IGNORE_INDEX``.

    NOTE(review): reply positions are computed from the start of the sequence,
    which assumes right padding.

    Args:
        sources: Conversations to tokenize.
        tokenizer: Tokenizer used for all text.
        max_length: Truncation length.

    Returns:
        dict with ``input_ids`` (padded tensor) and ``labels`` (same shape).
    """
    conversations = []
    intermediates = []
    for source in sources:
        header = f"{default_conversation.system}"
        conversation, intermediate = _add_speaker_and_signal(header, source)
        conversations.append(conversation)
        intermediates.append(intermediate)

    conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
    input_ids = conversations_tokenized["input_ids"]
    targets = copy.deepcopy(input_ids)

    assert len(targets) == len(intermediates)
    for target, inters in zip(targets, intermediates):
        mask = torch.zeros_like(target, dtype=torch.bool)
        for start, end in inters:
            # Tokenize both prefixes WITHOUT padding: if they were padded together,
            # both rows would have the same length and only one token per reply
            # would end up supervised.
            start_ids, end_ids = tokenizer([start, end], max_length=max_length, truncation=True)["input_ids"]

            # The -1 is kept from the original. It presumably compensates for the
            # trailing space after the speaker tag being tokenized on its own when
            # ``start`` is tokenized alone -- TODO confirm for the tokenizer in use.
            start_idx = len(start_ids) - 1
            end_idx = len(end_ids)

            mask[start_idx:end_idx] = True
        target[~mask] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)


def _add_speaker_and_signal(header: str,
                            source: List[Dict],
                            get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
    """Render one conversation as a single training string.

    Every turn becomes ``"<speaker>: <value></s>"``. The speaker is
    ``default_conversation.roles[0]`` for ``"human"``,
    ``default_conversation.roles[1]`` for ``"gpt"`` (both matched
    case-insensitively), and ``'unknown'`` otherwise.

    Args:
        header: Text placed before the first turn.
        source: Turns, each a dict with ``"from"`` and ``"value"`` keys.
        get_conversation: When False, turns are not appended to the running
            text, so the returned conversation is just ``header`` and every
            span is built on ``header`` alone.

    Returns:
        ``(conversation, spans)``. ``spans`` holds one ``[prefix, prefix_and_reply]``
        pair per ``"gpt"`` turn: ``prefix`` is the running text up to and
        including the speaker tag, and ``prefix_and_reply`` additionally
        contains the reply and its end marker.
    """
    role_slot = {"human": 0, "gpt": 1}
    transcript = header
    reply_spans = []
    for turn in source:
        speaker_key = turn["from"].lower()
        slot = role_slot.get(speaker_key)
        speaker = 'unknown' if slot is None else default_conversation.roles[slot]

        tag = speaker + ": "
        rendered = tag + turn["value"] + DEFAULT_EOS_TOKEN
        if speaker_key == "gpt":
            reply_spans.append([transcript + tag, transcript + rendered])
        if get_conversation:
            transcript += rendered
    return transcript, reply_spans


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Records are read with ``jload(data_path)``. The first record decides the
    format of the whole file:

    * Instruction records (no ``"conversations"`` key) carry ``"instruction"``,
      an optional ``"input"`` and ``"output"``. The prompt is rendered with
      ``PROMPT_DICT`` and only ``output`` followed by the EOS token is supervised.
    * Conversation records carry a ``"conversations"`` list of turns and are
      tokenized by ``preprocess_conversation``.

    Args:
        data_path: Path handed to ``jload``.
        tokenizer: Tokenizer used for all tokenization.
        max_datasets_size: If not ``None``, keep only the first this-many records.
        max_length: Maximum tokens per example; longer examples are truncated.

    Raises:
        ValueError: If no records remain after loading and limiting.
    """

    def __init__(self,
                 data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 max_datasets_size: int = None,
                 max_length: int = 512):
        super().__init__()
        logger.info("Loading data...")
        list_data_dict = jload(data_path)
        logger.info(f"Loaded {len(list_data_dict)} examples.")

        if max_datasets_size is not None:
            logger.info(f"Limiting dataset to {max_datasets_size} examples.")
            list_data_dict = list_data_dict[:max_datasets_size]

        # Format detection below reads the first record; report an empty
        # dataset clearly instead of failing with a bare IndexError.
        if not list_data_dict:
            raise ValueError(f"No examples to load from {data_path!r} "
                             f"(max_datasets_size={max_datasets_size}).")

        logger.info("Formatting inputs...")
        if "conversations" not in list_data_dict[0]:
            prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
            sources = [
                prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
                for example in list_data_dict
            ]
            targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

            if is_rank_0():
                logger.info("Tokenizing inputs... This may take some time...")

            data_dict = preprocess(sources, targets, tokenizer, max_length)
        else:
            if is_rank_0():
                logger.info("Tokenizing inputs... This may take some time...")

            sources = [conv["conversations"] for conv in list_data_dict]
            data_dict = preprocess_conversation(sources, tokenizer, max_length)

        if is_rank_0():
            logger.info("Tokenizing finish.")

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Stack dataset items into one training batch.

    ``input_ids`` are padded at the end with the tokenizer's pad id and
    ``labels`` with ``IGNORE_INDEX``. The attention mask is True wherever the
    padded ``input_ids`` differ from the pad id.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_id = self.tokenizer.pad_token_id
        id_rows = [item["input_ids"] for item in instances]
        label_rows = [item["labels"] for item in instances]

        batch_input_ids = torch.nn.utils.rnn.pad_sequence(id_rows, batch_first=True, padding_value=pad_id)
        batch_labels = torch.nn.utils.rnn.pad_sequence(label_rows, batch_first=True, padding_value=IGNORE_INDEX)
        return {
            "input_ids": batch_input_ids,
            "labels": batch_labels,
            "attention_mask": batch_input_ids.ne(pad_id),
        }