# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
chenych's avatar
chenych committed
19
from typing import TYPE_CHECKING, Any, Literal, Optional
chenych's avatar
chenych committed
20

chenych's avatar
chenych committed
21
import numpy as np
chenych's avatar
chenych committed
22
import torch
luopl's avatar
luopl committed
23
import torch.nn.functional as F
chenych's avatar
chenych committed
24
25
from transformers import DataCollatorForSeq2Seq

chenych's avatar
chenych committed
26
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
chenych's avatar
chenych committed
27
from ..extras.misc import get_current_device
luopl's avatar
luopl committed
28
29
30
31
32
33
from ..extras.packages import is_pillow_available


if is_pillow_available():
    from PIL import Image

chenych's avatar
chenych committed
34

luopl's avatar
luopl committed
35
36
37
38
39
40
if TYPE_CHECKING:
    from transformers import ProcessorMixin

    from .template import Template


chenych's avatar
chenych committed
41
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
    r"""Turn a 2d mask of sequence indices into an additive 4d attention mask.

    The input has shape (batch_size, seq_len): `0` marks padding and every other value
    identifies the packed sequence a token belongs to. The result has shape
    (batch_size, 1, seq_len, seq_len). A query position may attend to a key position only
    if both carry the same index, the key is not padding, and the key does not come after
    the query.

    e.g.
    ```python
    # input
    [[1, 1, 2, 2, 2, 0]]
    # output
    [
        [
            [
                [o, x, x, x, x, x],
                [o, o, x, x, x, x],
                [x, x, o, x, x, x],
                [x, x, o, o, x, x],
                [x, x, o, o, o, x],
                [x, x, x, x, x, x],
            ]
        ]
    ]
    ```
    where `o` equals to `0.0`, `x` equals to `min_dtype`.

    Args:
        attention_mask_with_indices: 2d tensor of per-token sequence indices.
        dtype: floating point dtype of the returned mask.

    Returns:
        The 4d mask, placed on the same device as the input.
    """
    _, seq_len = attention_mask_with_indices.shape

    # CPU inputs are processed on the current device; anything else stays where it is.
    origin_device = attention_mask_with_indices.device
    if origin_device.type == "cpu":
        work_device = get_current_device()
    else:
        work_device = origin_device

    relocated = work_device != origin_device
    if relocated:
        attention_mask_with_indices = attention_mask_with_indices.to(work_device)

    blocked_value = torch.finfo(dtype).min
    allowed_value = torch.tensor(0, dtype=dtype, device=work_device)

    key_ids = attention_mask_with_indices[:, None, None, :]  # [bsz, 1, 1, seq_len]
    query_ids = attention_mask_with_indices[:, None, :, None]  # [bsz, 1, seq_len, 1]
    key_is_real = key_ids != 0
    causal = torch.ones((seq_len, seq_len), dtype=torch.bool, device=work_device).tril()
    visible = (key_ids == query_ids) & key_is_real & causal

    # Visible positions become 0, everything else the most negative value of `dtype`.
    mask_4d = torch.where(visible, allowed_value, blocked_value)

    # Hand the result back on the caller's device.
    if relocated:
        mask_4d = mask_4d.to(origin_device)

    return mask_4d


@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    r"""Data collator that supports VLMs.

    Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
    """

    template: Optional["Template"] = None
    processor: Optional["ProcessorMixin"] = None

    def __post_init__(self):
        if self.template is None:
            raise ValueError("Template is required for MultiModalDataCollator.")

    def _encode_fake_mm_input(self, placeholder: str, images: list, audios: list) -> list[int]:
        r"""Tokenize a single-turn user message holding one multimodal placeholder.

        The message is expanded by the template's mm plugin, encoded without special tokens
        and post-processed by the plugin. No videos are passed.
        """
        mm_plugin = self.template.mm_plugin
        messages = [{"role": "user", "content": placeholder}]
        messages = mm_plugin.process_messages(messages, images, [], audios, self.processor)
        token_ids = self.tokenizer.encode(messages[0]["content"], add_special_tokens=False)
        token_ids, _ = mm_plugin.process_token_ids(
            token_ids, None, images, [], audios, self.tokenizer, self.processor
        )
        return token_ids

    def _add_rope_index(self, batch: dict[str, Any], mm_inputs: dict[str, Any]) -> None:
        r"""Store `position_ids` and `rope_deltas` from `self.model.get_rope_index` in `batch`."""
        rope_index_kwargs = {
            "input_ids": batch["input_ids"],
            "image_grid_thw": mm_inputs.get("image_grid_thw"),
            "video_grid_thw": mm_inputs.get("video_grid_thw"),
            "attention_mask": batch["attention_mask"],
        }
        if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
            rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
        if "video_second_per_grid" in mm_inputs:  # for qwen2omni
            rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")

        if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
            feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
            if feature_attention_mask is not None:
                # FIXME need to get video image lengths
                rope_index_kwargs["audio_seqlens"] = torch.sum(feature_attention_mask, dim=1)

            # Number of masked-out positions per row, subtracted from the returned deltas.
            delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
            new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
            # avoid inplace operation FIXME
            batch["position_ids"] = new_position_ids.clone()
            batch["rope_deltas"] = rope_deltas - delta0
        else:  # for qwen2vl
            batch["position_ids"], batch["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        r"""Collate features into a padded batch together with the multimodal model inputs.

        The `images`, `videos` and `audios` entries are popped from every feature. When the
        batch carries no image/video (or no audio) although the template defines the matching
        token, one fake input is attached to the first feature; its tokens are added with
        attention mask 0 and label `IGNORE_INDEX`.

        Args:
            features: examples with `input_ids`, `attention_mask`, `labels` and optional media.

        Returns:
            The padded batch merged with the outputs of the mm plugin. When the merged batch
            contains `image_bound` (minicpmv), a dict with `data`, `input_ids` and `labels`
            is returned instead.

        Raises:
            ImportError: if a fake image is required but Pillow is not installed.
        """
        batch_images, batch_videos, batch_audios = [], [], []
        batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
        for feature in features:
            images = feature.pop("images", None) or []
            videos = feature.pop("videos", None) or []
            audios = feature.pop("audios", None) or []
            batch_images.extend(images)
            batch_videos.extend(videos)
            batch_audios.extend(audios)
            batch_imglens.append(len(images))
            batch_vidlens.append(len(videos))
            batch_audlens.append(len(audios))
            batch_input_ids.append(feature["input_ids"])

        mm_plugin = self.template.mm_plugin
        fake_input_ids = []
        if (
            mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
        ):  # avoid process hanging in zero3/fsdp case
            if not is_pillow_available():
                # `Image` is only imported when Pillow is installed.
                raise ImportError("Pillow is required to build the fake image for image-free batches.")

            fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
            fake_input_ids.extend(self._encode_fake_mm_input(IMAGE_PLACEHOLDER, fake_images, []))
            batch_images = fake_images
            batch_imglens[0] = 1

        if mm_plugin.audio_token is not None and sum(batch_audlens) == 0:  # avoid process hanging in zero3/fsdp case
            fake_audios = [np.zeros(1600)]
            fake_input_ids.extend(self._encode_fake_mm_input(AUDIO_PLACEHOLDER, [], fake_audios))
            batch_audios = fake_audios
            batch_audlens[0] = 1

        if fake_input_ids:
            # Attach the fake tokens on the padding side of the first example.
            first = features[0]
            masked = [0] * len(fake_input_ids)
            ignored = [IGNORE_INDEX] * len(fake_input_ids)
            if self.tokenizer.padding_side == "right":
                first["input_ids"] = first["input_ids"] + fake_input_ids
                first["attention_mask"] = first["attention_mask"] + masked
                first["labels"] = first["labels"] + ignored
            else:
                first["input_ids"] = fake_input_ids + first["input_ids"]
                first["attention_mask"] = masked + first["attention_mask"]
                first["labels"] = ignored + first["labels"]

            batch_input_ids[0] = first["input_ids"]

        mm_inputs = mm_plugin.get_mm_inputs(
            batch_images,
            batch_videos,
            batch_audios,
            batch_imglens,
            batch_vidlens,
            batch_audlens,
            batch_input_ids,
            self.processor,
        )
        if "token_type_ids" in mm_inputs:
            # Move token type ids into the features so they are padded with the other fields.
            token_type_ids = mm_inputs.pop("token_type_ids")
            for i, feature in enumerate(features):
                feature["token_type_ids"] = token_type_ids[i]

        batch: dict[str, torch.Tensor] = super().__call__(features)

        if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
            self._add_rope_index(batch, mm_inputs)

        if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
            seq_len = batch["input_ids"].size(1)
            orig_len = cross_attention_mask.size(1)
            mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))

        batch.update(mm_inputs)

        if "image_bound" in batch:  # for minicpmv inputs
            bsz, seq_length = batch["input_ids"].shape
            batch["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
            return {"data": batch, "input_ids": batch["input_ids"], "labels": batch["labels"]}

        return batch


@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
    r"""Data collator for 4d attention mask."""

    block_diag_attn: bool = False
    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
    compute_dtype: "torch.dtype" = torch.float32

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        r"""Collate features, optionally expand the attention mask to 4d and cast float tensors.

        The attention mask is expanded only when `block_diag_attn` is set and the attention
        implementation is not `flash_attention_2`. Every floating point tensor in the batch
        is then converted to `compute_dtype`.
        """
        batch = super().__call__(features)

        needs_4d_mask = self.block_diag_attn and self.attn_implementation != "flash_attention_2"
        if needs_4d_mask:
            batch["attention_mask"] = prepare_4d_attention_mask(batch["attention_mask"], self.compute_dtype)

        # cast data dtype for paligemma
        float_keys = [
            name for name, tensor in batch.items() if torch.is_tensor(tensor) and torch.is_floating_point(tensor)
        ]
        for name in float_keys:
            batch[name] = batch[name].to(self.compute_dtype)

        return batch


@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""Data collator for pairwise data."""

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        r"""Pad batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.

        The `images`, `videos` and `audios` entries are optional, matching the parent collator:
        a missing key is forwarded as `None`, which the parent treats as "no media".
        """
        concatenated_features = [
            {
                "input_ids": feature[f"{key}_input_ids"],
                "attention_mask": feature[f"{key}_attention_mask"],
                "labels": feature[f"{key}_labels"],
                "images": feature.get("images"),
                "videos": feature.get("videos"),
                "audios": feature.get("audios"),
            }
            for key in ("chosen", "rejected")
            for feature in features
        ]
        return super().__call__(concatenated_features)


@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""Data collator for KTO data."""

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        r"""Collate the target examples and their KL counterparts into one batch.

        The target and KL examples are collated separately by the parent collator. The KL
        results are stored in the target batch under `kl_`-prefixed keys and the per-example
        `kto_tags` are added as a tensor.

        The `images`, `videos` and `audios` entries are optional, matching the parent collator:
        a missing key is forwarded as `None`, which the parent treats as "no media".
        """
        target_features = []
        kl_features = []
        kto_tags = []
        for feature in features:
            # Both variants of an example share the same media.
            media = {name: feature.get(name) for name in ("images", "videos", "audios")}
            target_features.append(
                {
                    "input_ids": feature["input_ids"],
                    "attention_mask": feature["attention_mask"],
                    "labels": feature["labels"],
                    **media,
                }
            )
            kl_features.append(
                {
                    "input_ids": feature["kl_input_ids"],
                    "attention_mask": feature["kl_attention_mask"],
                    "labels": feature["kl_labels"],
                    **media,
                }
            )
            kto_tags.append(feature["kto_tags"])

        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
        if "cross_attention_mask" in kl_batch:  # for mllama inputs.
            batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]

        if "token_type_ids" in kl_batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch