aligner.py
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

from ..extras import logging
from .data_utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .mm_plugin import ImageInput, VideoInput
    from .parser import DatasetAttr


logger = logging.get_logger(__name__)


def _convert_images(
    images: Union["ImageInput", Sequence["ImageInput"]],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Optional[List["ImageInput"]]:
    r"""
    Optionally prepends the dataset directory to each image path when loading from local disk.
    """
    if not isinstance(images, list):
        images = [images]
    elif len(images) == 0:
        return None
    else:
        images = images[:]

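    # Rewrite entries only for datasets loaded from a local script/file, and only when
    # joining data_args.image_dir with the entry points at an existing file; anything
    # else is passed through unchanged.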
    if dataset_attr.load_from in ["script", "file"]:
        for i in range(len(images)):
            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
                images[i] = os.path.join(data_args.image_dir, images[i])

    return images


def _convert_videos(
    videos: Union["VideoInput", Sequence["VideoInput"]],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Optional[List["VideoInput"]]:
    r"""
    Optionally prepends the dataset directory to each video path when loading from local disk.
    """
    if not isinstance(videos, list):
        videos = [videos]
    elif len(videos) == 0:
        return None
    else:
        videos = videos[:]

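    # Mirrors _convert_images: note that video paths are also resolved against
    # data_args.image_dir (this module defines no separate video directory).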
    if dataset_attr.load_from in ["script", "file"]:
        for i in range(len(videos)):
            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
                videos[i] = os.path.join(data_args.image_dir, videos[i])

    return videos


def convert_alpaca(
    example: Dict[str, Any],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Dict[str, Any]:
    r"""
    Converts an alpaca format example to the standard format.
    """
    prompt = []
    if dataset_attr.history and isinstance(example[dataset_attr.history], list):
        for old_prompt, old_response in example[dataset_attr.history]:
            prompt.append({"role": Role.USER.value, "content": old_prompt})
            prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

    query = []
    if dataset_attr.prompt and example[dataset_attr.prompt]:
        query.append(example[dataset_attr.prompt])

    if dataset_attr.query and example[dataset_attr.query]:
        query.append(example[dataset_attr.query])

    prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"

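    # For KTO data the boolean tag decides the ordering below: a True tag appends an
    # empty assistant placeholder after the real answer, a False tag prepends it.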
    if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
        if example[dataset_attr.kto_tag]:
            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
        else:
            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
    elif (
        dataset_attr.ranking
        and isinstance(example[dataset_attr.chosen], str)
        and isinstance(example[dataset_attr.rejected], str)
    ):  # pairwise example
        response = [
            {"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
            {"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
        ]
    elif dataset_attr.response and isinstance(example[dataset_attr.response], str):  # normal example
        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
    else:  # unsupervised
        response = []

    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
    output = {
        "_prompt": prompt,
        "_response": response,
        "_system": example[dataset_attr.system] if dataset_attr.system else "",
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
        "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
    }
    return output


def convert_sharegpt(
    example: Dict[str, Any],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Dict[str, Any]:
    r"""
    Converts a sharegpt format example to the standard format.
    """
    tag_mapping = {
        dataset_attr.user_tag: Role.USER.value,
        dataset_attr.assistant_tag: Role.ASSISTANT.value,
        dataset_attr.observation_tag: Role.OBSERVATION.value,
        dataset_attr.function_tag: Role.FUNCTION.value,
        dataset_attr.system_tag: Role.SYSTEM.value,
    }
    odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
    even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
    accept_tags = (odd_tags, even_tags)
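    # Conversations must alternate roles: positions 0, 2, 4, ... (the "odd" turns when
    # counting from one) may only hold user/observation messages and positions
    # 1, 3, 5, ... only assistant/function messages; accept_tags[turn_idx % 2]
    # enforces this in the loop below.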
    messages = example[dataset_attr.messages]
    if (
        dataset_attr.system_tag
        and len(messages) != 0
        and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
    ):
        system = messages[0][dataset_attr.content_tag]
        messages = messages[1:]
    else:
        system = example[dataset_attr.system] if dataset_attr.system else ""

    aligned_messages = []
    broken_data = False
    for turn_idx, message in enumerate(messages):
        if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
            logger.warning_rank0(f"Invalid role tag in {messages}.")
            broken_data = True

        aligned_messages.append(
            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
        )

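    # Sanity check on turn counts: a non-ranking conversation must end on an assistant
    # turn (even number of aligned messages), while a ranking conversation ends on a
    # user turn (odd count) whose candidate answers come from chosen/rejected instead.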
    if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
        dataset_attr.ranking and len(aligned_messages) % 2 == 0
    ):
        logger.warning_rank0(f"Invalid message count in {messages}.")
        broken_data = True

    if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
        prompt = aligned_messages[:-1]
        response = aligned_messages[-1:]
        if example[dataset_attr.kto_tag]:
            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
        else:
            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
    elif (
        dataset_attr.ranking
        and isinstance(example[dataset_attr.chosen], dict)
        and isinstance(example[dataset_attr.rejected], dict)
    ):  # pairwise example
        chosen = example[dataset_attr.chosen]
        rejected = example[dataset_attr.rejected]
        if (
            chosen[dataset_attr.role_tag] not in accept_tags[-1]
            or rejected[dataset_attr.role_tag] not in accept_tags[-1]
        ):
            logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
            broken_data = True

        prompt = aligned_messages
        response = [
            {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
            {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
        ]
    else:  # normal example
        prompt = aligned_messages[:-1]
        response = aligned_messages[-1:]

    if broken_data:
        logger.warning_rank0("Skipping this abnormal example.")
        prompt, response = [], []

    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
    output = {
        "_prompt": prompt,
        "_response": response,
        "_system": system,
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
        "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
    }
    return output


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligns the dataset to the following standard format:
        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        _system: "..."
        _tools: "..."
        _images: []
        _videos: []
    """
    if dataset_attr.formatting == "alpaca":
        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
    else:
        convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)

    column_names = list(next(iter(dataset)).keys())
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Converting format of dataset",
        )

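    # Every original column is dropped after conversion so that only the standard keys
    # (_prompt, _response, _system, _tools, _images, _videos) remain; num_proc and the
    # cache handling above only apply in the non-streaming case.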
    return dataset.map(
        convert_func,
        batched=False,
        remove_columns=column_names,
        **kwargs,
    )
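

# A minimal usage sketch (hypothetical names; the real pipeline constructs these
# objects from its own configs before calling align_dataset):
#
#   aligned = align_dataset(raw_dataset, dataset_attr, data_args, training_args)
#   # each row of `aligned` now carries _prompt/_response/_system/_tools/_images/_videos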