"README_ORIGIN.md" did not exist on "b73cc99425f15a6f5e6abe5094e48f81974047de"
collator.py 13.9 KB
Newer Older
chenych's avatar
chenych committed
1
# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
chenych's avatar
chenych committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional

import numpy as np
import torch
import torch.nn.functional as F
from peft import PeftModel
from transformers import DataCollatorForSeq2Seq

from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.packages import is_pillow_available


if is_pillow_available():
    from PIL import Image


if TYPE_CHECKING:
    from transformers import ProcessorMixin

    from .template import Template


def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
    r"""Expand 2d attention mask to 4d attention mask.

    Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
    handle packed sequences and transform the mask to lower triangular form to prevent future peeking.

    A query position may attend to a key position only when both carry the same non-zero index
    (i.e. belong to the same packed sub-sequence) and the key is not after the query. Index ``0``
    marks padding; its rows and columns are fully masked.

    Args:
        attention_mask_with_indices: Tensor of shape (batch_size, seq_len) holding sub-sequence
            indices, with ``0`` for padding positions.
        dtype: Floating point dtype of the returned mask. Masked positions are filled with
            ``torch.finfo(dtype).min``.

    Returns:
        Tensor of shape (batch_size, 1, seq_len, seq_len) and dtype ``dtype``, on the same device as
        ``attention_mask_with_indices``. It holds ``0.0`` where attention is allowed and ``min_dtype``
        elsewhere.

    e.g.
    ```python
    # input
    [[1, 1, 2, 2, 2, 0]]
    # output
    [
        [
            [
                [o, x, x, x, x, x],
                [o, o, x, x, x, x],
                [x, x, o, x, x, x],
                [x, x, o, o, x, x],
                [x, x, o, o, o, x],
                [x, x, x, x, x, x],
            ]
        ]
    ]
    ```
    where `o` equals to `0.0`, `x` equals to `min_dtype`.
    """
    # Build every helper tensor on the input's device so CUDA inputs do not hit a CPU/GPU mismatch.
    device = attention_mask_with_indices.device
    _, seq_len = attention_mask_with_indices.size()
    min_dtype = torch.finfo(dtype).min
    zero_tensor = torch.tensor(0, dtype=dtype, device=device)

    # Create a non-padding mask (masks out padding keys).
    non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
    # Create indices for comparison.
    indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
    indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
    # Create a lower triangular (causal) mask.
    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=device))
    # A padding query row has index 0, which never equals a non-padding key, so that row is fully masked too.
    attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
    # Invert the attention mask: allowed -> 0.0, blocked -> most negative value of `dtype`.
    attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
    return attention_mask_4d


@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    r"""Data collator that supports VLMs.

    Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
    The media lists are popped from each feature and converted into model inputs by the template's
    ``mm_plugin``. The results are merged into the padded batch produced by ``DataCollatorForSeq2Seq``.
    """

    # Chat template whose `mm_plugin` turns media into model inputs; required (checked in __post_init__).
    template: Optional["Template"] = None
    # Processor forwarded to every `mm_plugin` call.
    processor: Optional["ProcessorMixin"] = None

    def __post_init__(self):
        """Validate the template, unwrap PEFT models and look up an optional ``get_rope_index`` for mrope."""
        if self.template is None:
            raise ValueError("Template is required for MultiModalDataCollator.")

        if isinstance(self.model, PeftModel):
            self.model = self.model.base_model.model

        # Find the object exposing `get_rope_index` (qwen2vl-style mrope), if any.
        rope_owner = None
        if self.model is not None:
            if hasattr(self.model, "get_rope_index"):  # transformers < 4.52.0 or qwen2.5 omni
                rope_owner = self.model
            elif hasattr(self.model, "model") and hasattr(self.model.model, "get_rope_index"):
                rope_owner = self.model.model  # transformers >= 4.52.0

        self.get_rope_func = None if rope_owner is None else rope_owner.get_rope_index

    def _fake_placeholder_ids(self, placeholder: str, images: list, audios: list) -> list[int]:
        """Return the token ids of a single user turn holding ``placeholder``, expanded by the plugin for dummy media."""
        mm_plugin = self.template.mm_plugin
        messages = mm_plugin.process_messages(
            [{"role": "user", "content": placeholder}], images, [], audios, self.processor
        )
        token_ids = self.tokenizer.encode(messages[0]["content"], add_special_tokens=False)
        token_ids, _ = mm_plugin.process_token_ids(token_ids, None, images, [], audios, self.tokenizer, self.processor)
        return token_ids

    def _add_mrope_position_ids(self, batch: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]) -> None:
        """Fill ``batch["position_ids"]`` and ``batch["rope_deltas"]`` via ``self.get_rope_func``."""
        rope_index_kwargs = {
            "input_ids": batch["input_ids"],
            "image_grid_thw": mm_inputs.get("image_grid_thw"),
            "video_grid_thw": mm_inputs.get("video_grid_thw"),
            # binarize; presumably because packed masks may carry sub-sequence indices > 1
            "attention_mask": (batch["attention_mask"] >= 1).float(),
        }
        if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
            rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
        elif "video_second_per_grid" in mm_inputs:  # for qwen2.5 omni
            rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")

        if getattr(self.model.config, "model_type", None) != "qwen2_5_omni_thinker":  # for qwen2vl
            batch["position_ids"], batch["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
            return

        # qwen2.5 omni: extra audio arguments, and rope deltas corrected by the number of padding tokens
        rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
        feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
        if feature_attention_mask is not None:  # FIXME: need to get video image lengths
            rope_index_kwargs["audio_seqlens"] = torch.sum(feature_attention_mask, dim=1)  # prepare for input

        batch["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
        padding_counts = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(-1)
        batch["rope_deltas"] = rope_deltas - padding_counts

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        """Collate features: gather media, build multimodal inputs, pad text, then attach the media tensors."""
        mm_plugin = self.template.mm_plugin
        batch_images, batch_videos, batch_audios = [], [], []
        batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
        for feature in features:
            images = feature.pop("images", None) or []
            videos = feature.pop("videos", None) or []
            audios = feature.pop("audios", None) or []
            for sink, lengths, items in (
                (batch_images, batch_imglens, images),
                (batch_videos, batch_vidlens, videos),
                (batch_audios, batch_audlens, audios),
            ):
                sink.extend(items)
                lengths.append(len(items))
            batch_input_ids.append(feature["input_ids"])

        # Inject dummy media when the batch has none (avoid process hanging in zero3/fsdp case).
        fake_input_ids = []
        if mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0:
            fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
            fake_input_ids.extend(self._fake_placeholder_ids(IMAGE_PLACEHOLDER, fake_images, []))
            batch_images = fake_images
            batch_imglens[0] = 1

        if mm_plugin.audio_token is not None and sum(batch_audlens) == 0:
            fake_audios = [np.zeros(1600)]
            fake_input_ids.extend(self._fake_placeholder_ids(AUDIO_PLACEHOLDER, [], fake_audios))
            batch_audios = fake_audios
            batch_audlens[0] = 1

        if fake_input_ids:
            # The dummy tokens go on the padding side of the first example; they are excluded from attention and loss.
            first = features[0]
            num_fake = len(fake_input_ids)
            if self.tokenizer.padding_side == "right":
                first["input_ids"] = first["input_ids"] + fake_input_ids
                first["attention_mask"] = first["attention_mask"] + [0] * num_fake
                first["labels"] = first["labels"] + [IGNORE_INDEX] * num_fake
            else:
                first["input_ids"] = fake_input_ids + first["input_ids"]
                first["attention_mask"] = [0] * num_fake + first["attention_mask"]
                first["labels"] = [IGNORE_INDEX] * num_fake + first["labels"]

            batch_input_ids[0] = first["input_ids"]

        mm_inputs = mm_plugin.get_mm_inputs(
            batch_images,
            batch_videos,
            batch_audios,
            batch_imglens,
            batch_vidlens,
            batch_audlens,
            batch_input_ids,
            self.processor,
        )
        if "token_type_ids" in mm_inputs:
            # per-example token types are padded together with the text by the parent collator
            token_type_ids = mm_inputs.pop("token_type_ids")
            for idx, feature in enumerate(features):
                feature["token_type_ids"] = token_type_ids[idx]

        batch: dict[str, torch.Tensor] = super().__call__(features)

        if self.get_rope_func is not None:
            self._add_mrope_position_ids(batch, mm_inputs)

        if (
            self.model is not None
            and getattr(self.model.config, "model_type", None) in ("qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker")
            and ("position_ids" not in batch or batch["position_ids"].dim() != 3)
        ):
            raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.")

        if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
            extra_len = batch["input_ids"].size(1) - cross_attention_mask.size(1)
            # pad dim 1 at the end so it matches the padded text length
            mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, extra_len))

        batch.update(mm_inputs)

        if "image_bound" in batch:  # for minicpmv inputs
            bsz, seq_length = batch["input_ids"].shape
            batch["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
            return {"data": batch, "input_ids": batch["input_ids"], "labels": batch["labels"]}

        return batch


@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
    r"""Data collator for 4d attention mask.

    Optionally replaces the 2D attention mask with a block-diagonal 4D mask (see
    ``prepare_4d_attention_mask``). It then casts every floating point tensor in the batch to
    ``compute_dtype``.
    """

    # Build a block-diagonal 4D mask from the (index-valued) 2D attention mask.
    block_diag_attn: bool = False
    # The 4D mask is skipped when this is "flash_attention_2".
    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
    # Target dtype for the 4D mask and for all floating point tensors in the batch.
    compute_dtype: "torch.dtype" = torch.float32

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        batch = super().__call__(features)

        wants_4d_mask = self.block_diag_attn and self.attn_implementation != "flash_attention_2"
        if wants_4d_mask:
            batch["attention_mask"] = prepare_4d_attention_mask(batch["attention_mask"], self.compute_dtype)

        # cast data dtype for paligemma
        floating_keys = [
            key for key, value in batch.items() if torch.is_tensor(value) and torch.is_floating_point(value)
        ]
        for key in floating_keys:
            batch[key] = batch[key].to(self.compute_dtype)

        return batch


@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""Data collator for pairwise data."""

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        r"""Pad batched data to the longest sequence in the batch.

        We generate 2 * n examples: the first n are the chosen examples and the last n are the rejected
        examples. Each one shares the media (images, videos, audios) of its source feature.
        """
        concatenated_features = [
            {
                "input_ids": feature[f"{side}_input_ids"],
                "attention_mask": feature[f"{side}_attention_mask"],
                "labels": feature[f"{side}_labels"],
                "images": feature["images"],
                "videos": feature["videos"],
                "audios": feature["audios"],
            }
            for side in ("chosen", "rejected")  # outer loop keeps all chosen before all rejected
            for feature in features
        ]
        return super().__call__(concatenated_features)


@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""Data collator for KTO data.

    Collates the target examples and their paired KL examples in two separate passes. The KL tensors
    are then merged into the target batch under ``kl_``-prefixed keys, together with a ``kto_tags``
    tensor.
    """

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        def select(feature: dict[str, Any], ids_key: str, mask_key: str, labels_key: str) -> dict[str, Any]:
            # Text fields come from the requested keys; media is shared between target and KL views.
            return {
                "input_ids": feature[ids_key],
                "attention_mask": feature[mask_key],
                "labels": feature[labels_key],
                "images": feature["images"],
                "videos": feature["videos"],
                "audios": feature["audios"],
            }

        target_features, kl_features, kto_tags = [], [], []
        for feature in features:
            target_features.append(select(feature, "input_ids", "attention_mask", "labels"))
            kl_features.append(select(feature, "kl_input_ids", "kl_attention_mask", "kl_labels"))
            kto_tags.append(feature["kto_tags"])

        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)

        for source_key, target_key in (
            ("input_ids", "kl_input_ids"),
            ("attention_mask", "kl_attention_mask"),
            ("labels", "kl_labels"),
        ):
            batch[target_key] = kl_batch[source_key]

        # optional model-specific tensors (cross_attention_mask is for mllama inputs)
        for source_key, target_key in (
            ("cross_attention_mask", "kl_cross_attention_mask"),
            ("token_type_ids", "kl_token_type_ids"),
        ):
            if source_key in kl_batch:
                batch[target_key] = kl_batch[source_key]

        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch