# data_utils.py
import copy
import random
from functools import partial
from typing import Callable, Dict, List, Optional

import torch
from datasets import DatasetDict, IterableDatasetDict, load_dataset
from torch import LongTensor
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer


def make_data_block(
    samples: Dict[str, List[str]],
    prompt_col_name: str,
    label_col_name: str,
    tokenizer: PreTrainedTokenizer,
    preprocess_fn: Optional[Callable] = None,
    sample_max_len: int = 1024,
    block_max_len: int = 2048,
    add_eos_token: bool = False,
    truncate_prompt: bool = True,
    merge_prompt_label: bool = False,
) -> Dict[str, List[List[List[int]]]]:
    """A simple implementation of text-generation-oriented smart batching that maximizes VRAM usage during evaluation.

    :param samples: Dict[str, List[str]], samples used to make data blocks
    :param prompt_col_name: str, name of the key in samples whose value stores the prompts
    :param label_col_name: str, name of the key in samples whose value stores the labels
    :param tokenizer: transformers.PreTrainedTokenizer, tokenizer used to tokenize samples
    :param preprocess_fn: Optional[Callable], optional function used to preprocess samples, e.g. to refactor
        their data structure; note the output of this function must be a dict whose keys contain at least
        `prompt_col_name` and `label_col_name`
    :param sample_max_len: int, defaults to 1024, max number of tokens in each sample (before padding)
    :param block_max_len: int, defaults to 2048, max number of tokens in each data block (after padding)
    :param add_eos_token: bool, defaults to False, whether to append eos_token to the label
    :param truncate_prompt: bool, defaults to True, whether to truncate the prompt when the sample's total
        number of tokens exceeds `sample_max_len`; if False, the label is truncated instead and the sample
        is dropped once all of its label tokens have been truncated
    :param merge_prompt_label: bool, defaults to False, if True the label is merged into the prompt; usually
        this is only required for language modeling tasks
    :return: Dict[str, List[List[List[int]]]], a dict whose keys are `input_ids`, `attention_mask` and
        `labels`, and whose values are lists of data blocks, each block holding one token id list per sample
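
    Example (a minimal sketch, not taken from upstream docs: the column names, texts and the `tokenizer`
    object below are placeholders, and the exact token ids in the returned blocks depend on the tokenizer
    actually used)::

        samples = {
            "instruction": ["Say hi.", "Count to three."],
            "output": ["Hi!", "1 2 3"],
        }
        blocks = make_data_block(
            samples,
            prompt_col_name="instruction",
            label_col_name="output",
            tokenizer=tokenizer,  # any transformers.PreTrainedTokenizer that has a pad token
            sample_max_len=512,
            block_max_len=1024,
        )
        assert list(blocks.keys()) == ["input_ids", "attention_mask", "labels"]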
    """
    if preprocess_fn:
        samples = preprocess_fn(samples)

    prompts = samples[prompt_col_name]
    labels = samples[label_col_name]

    # tokenize samples
    tokenized_prompts = tokenizer(prompts, truncation=False)["input_ids"]
    tokenized_labels = tokenizer(labels, truncation=False)["input_ids"]

    # filter tokenized samples by length
    dropped_indices = []
    for idx, (tokenized_prompt, tokenized_label) in enumerate(zip(tokenized_prompts, tokenized_labels)):
        if add_eos_token:
            tokenized_label += [tokenizer.eos_token_id]
        len_prompt = len(tokenized_prompt)
        len_label = len(tokenized_label)
        exceed_len = len_prompt + len_label - sample_max_len
        if exceed_len > 0:
            if truncate_prompt:
                tokenized_prompt = tokenized_prompt[exceed_len:]
            else:
                tokenized_label = tokenized_label[:-exceed_len]
        tokenized_prompts[idx] = tokenized_prompt
        tokenized_labels[idx] = tokenized_label
        if not tokenized_label:
            dropped_indices.append(idx)

    # make data blocks of samples: greedily pack the length-sorted samples into blocks so that
    # each block's padded size (number of samples in the block times the block's max sample
    # length) stays within `block_max_len`
    tokenized_samples = sorted(
        [(p, l) for idx, (p, l) in enumerate(zip(tokenized_prompts, tokenized_labels)) if idx not in dropped_indices],
        key=lambda x: (len(x[0]) + len(x[1])) if merge_prompt_label else len(x[0]),
    )
    sample_blocks = []
    sample_block = []
    blk_max_len = 0
    blk_total_len = 0
    for tokenized_sample in tokenized_samples:
        prompt_ids, label_ids = tokenized_sample
        ori_sample_len = len(prompt_ids)
        if merge_prompt_label:
            ori_sample_len += len(label_ids)
        if ori_sample_len <= blk_max_len:
            # the new sample will simply be padded up to the block's current max length
            additional_len = blk_max_len
            sample_len = blk_max_len
        else:
            # the new sample raises the block's max length, so every sample already in the block
            # needs extra padding on top of the new sample's own length
            additional_len = len(sample_block) * (ori_sample_len - blk_max_len) + ori_sample_len
            sample_len = ori_sample_len

        if blk_total_len + additional_len > block_max_len:
            sample_blocks.append((copy.copy(sample_block), blk_max_len))
            sample_block = []
            blk_max_len = 0
            blk_total_len = 0
            sample_len = ori_sample_len
            additional_len = ori_sample_len

        sample_block.append(tokenized_sample)
        blk_max_len = max(blk_max_len, sample_len)
        blk_total_len += additional_len

    if sample_block:
        sample_blocks.append((copy.copy(sample_block), blk_max_len))
    del sample_block
    del blk_max_len
    del blk_total_len

    new_samples = {"input_ids": [], "attention_mask": [], "labels": []}
    # pad each data block internally
    for block, blk_max_len in sample_blocks:
        input_ids = []
        attention_mask = []
        label_ids = []
        label_max_len = max([len(sample[1]) for sample in block])

        for sample in block:
            tokenized_prompt, tokenized_label = sample
            sample_len = len(tokenized_prompt)
            if merge_prompt_label:
                sample_len += len(tokenized_label)
            pad_num = blk_max_len - sample_len
            if merge_prompt_label:
                input_ids.append([tokenizer.pad_token_id] * pad_num + tokenized_prompt + tokenized_label)
                label_ids.append([-100] * (pad_num + len(tokenized_prompt)) + tokenized_label)
            else:
                input_ids.append([tokenizer.pad_token_id] * pad_num + tokenized_prompt)
                label_ids.append([-100] * (label_max_len - len(tokenized_label)) + tokenized_label)
            attention_mask.append([0] * pad_num + [1] * sample_len)

        new_samples["input_ids"].append(input_ids)
        new_samples["attention_mask"].append(attention_mask)
        new_samples["labels"].append(label_ids)

    return new_samples


def collate_data(blocks: List[Dict[str, List[List[int]]]], pad_token_id: int) -> Dict[str, LongTensor]:
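    """Collate data blocks produced by `make_data_block` into a single batch of tensors.

    Each block is left-padded up to the longest `input_ids` (using `pad_token_id`) and the longest
    `labels` (using -100) across all blocks, then all blocks are concatenated along the batch dimension.

    :param blocks: List[Dict[str, List[List[int]]]], data blocks to collate
    :param pad_token_id: int, token id used to pad `input_ids`
    :return: Dict[str, torch.LongTensor], a batch whose keys are `input_ids`, `attention_mask` and `labels`
    """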
    def pad_block(block, pads):
        return torch.cat((pads.to(block.device), block), dim=-1)

    input_ids_blocks = [LongTensor(block["input_ids"]) for block in blocks]
    attention_mask_blocks = [LongTensor(block["attention_mask"]) for block in blocks]
    label_blocks = [LongTensor(block["labels"]) for block in blocks]

    bsz = len(blocks)
    inp_max_len = max([block.size(-1) for block in input_ids_blocks])
    label_max_len = max([block.size(-1) for block in label_blocks])

    for i in range(bsz):
        block_bsz, block_inp_len = input_ids_blocks[i].shape
        block_label_len = label_blocks[i].shape[-1]
        pad_num = inp_max_len - block_inp_len
        if pad_num > 0:
            input_ids_blocks[i] = pad_block(input_ids_blocks[i], torch.full((block_bsz, pad_num), pad_token_id, dtype=torch.long))
            attention_mask_blocks[i] = pad_block(attention_mask_blocks[i], torch.zeros((block_bsz, pad_num), dtype=torch.long))
        label_pad_num = label_max_len - block_label_len
        if label_pad_num > 0:
            label_blocks[i] = pad_block(label_blocks[i], torch.full((block_bsz, label_pad_num), -100, dtype=torch.long))

    return {
        "input_ids": torch.cat(input_ids_blocks, dim=0).long(),
        "attention_mask": torch.cat(attention_mask_blocks, dim=0).long(),
        "labels": torch.cat(label_blocks, dim=0).long(),
    }


def get_dataloader(
    data_path_or_name: str,
    prompt_col_name: str,
    label_col_name: str,
    tokenizer: PreTrainedTokenizer,
    load_fn: Optional[Callable] = None,
    preprocess_fn: Optional[Callable] = None,
    num_samples: int = 128,
    sample_max_len: int = 1024,
    block_max_len: int = 2048,
    add_eos_token: bool = False,
    truncate_prompt: bool = True,
    merge_prompt_label: bool = False,
    load_fn_kwargs: Optional[dict] = None,
    preprocess_fn_kwargs: Optional[dict] = None,
    **kwargs,
) -> DataLoader:
    """load dataset and build dataloader

    :param data_path_or_name: str, dataset name in hf-hub or local file path
    :param prompt_col_name: str, see `make_data_block`
    :param label_col_name: str, see `make_data_block`
    :param tokenizer: transformers.PreTrainedTokenizer, see `make_data_block`
    :param load_fn: Optional[Callable], defaults to None, function used to load dataset, if not specified,
        use `datasets.load_dataset`
    :param preprocess_fn: Optional[Callable], see `make_data_block`
    :param num_samples: int, defaults to 128, total number of samples used for evaluation
    :param sample_max_len: int, see `make_data_block`
    :param block_max_len: int, see `make_data_block`
    :param add_eos_token: bool, see `make_data_block`
    :param truncate_prompt: bool, see `make_data_block`
    :param merge_prompt_label: bool, see `make_data_block`
    :param load_fn_kwargs: Optional[dict], defaults to None, keyword arguments used
        for `load_fn` or `datasets.load_dataset`
    :param preprocess_fn_kwargs: Optional[dict], defaults to None, keyword arguments used
        for `preprocess_fn`
    :param kwargs: additional keyword arguments passed to torch's `DataLoader` initialization;
        note that the values of `batch_size`, `shuffle` and `collate_fn` will always be overridden with fixed values
    :return: torch.utils.data.DataLoader
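
    Example (a minimal sketch; the dataset, column names and model name below are illustrative assumptions
    rather than defaults of this module, and loading them requires network access)::

        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
        dl = get_dataloader(
            "tatsu-lab/alpaca",
            prompt_col_name="instruction",
            label_col_name="output",
            tokenizer=tokenizer,
            num_samples=64,
        )
        batch = next(iter(dl))  # dict with `input_ids`, `attention_mask` and `labels` tensors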
    """

    if not load_fn_kwargs:
        load_fn_kwargs = {}
    if not preprocess_fn_kwargs:
        preprocess_fn_kwargs = {}

    if load_fn:
        ds = load_fn(data_path_or_name, **load_fn_kwargs)
    else:
        ds = load_dataset(data_path_or_name, **load_fn_kwargs)
    if isinstance(ds, (DatasetDict, IterableDatasetDict)):
        if "evaluation" in ds:
            ds = ds["evaluation"]
        elif "test" in ds:
            ds = ds["test"]
        else:
            ds = ds["train"]

    ds = ds.select(
        indices=random.sample(range(len(ds)), min(len(ds), num_samples)),
        keep_in_memory=True,
    )
    ds = ds.map(
        make_data_block,
        batched=True,
        batch_size=len(ds),
        num_proc=1,
        remove_columns=ds.column_names,
        keep_in_memory=True,
        load_from_cache_file=False,
        fn_kwargs={
            "prompt_col_name": prompt_col_name,
            "label_col_name": label_col_name,
            "tokenizer": tokenizer,
            "preprocess_fn": partial(preprocess_fn, **preprocess_fn_kwargs),
            "sample_max_len": sample_max_len,
            "block_max_len": block_max_len,
            "add_eos_token": add_eos_token,
            "truncate_prompt": truncate_prompt,
            "merge_prompt_label": merge_prompt_label,
        },
    )

    # always override these arguments, regardless of any user-specified values in kwargs
    kwargs["batch_size"] = 1
    kwargs["shuffle"] = False
    kwargs["collate_fn"] = partial(collate_data, pad_token_id=tokenizer.pad_token_id)
    dl = DataLoader(ds, **kwargs)

    return dl


__all__ = ["make_data_block", "collate_data", "get_dataloader"]