"""
    Multimodality utils
"""

from abc import abstractmethod
from typing import Callable, List, Optional, Tuple

import torch
from torch import nn

from sglang.srt.managers.schedule_batch import (
    MultimodalInputs,
    global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.utils import logger


class MultiModalityDataPaddingPattern:
    """
    Data tokens (like image tokens) often need special handling during padding
    to maintain model compatibility. This class defines the interface for
    implementing different padding strategies for data tokens.
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Pad the input ids sequence containing data tokens, and replace them with pad_values
        """
        pass


class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens are enclosed by special start/end token pairs
    (e.g. <image>...</image>), provided as data_token_pairs.

    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
    """

    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
        self.data_token_id_pairs = data_token_pairs

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        This function will replace the data tokens in between the token pairs with pad_values accordingly
        """
        pad_values = mm_inputs.pad_values
        data_token_pairs = self.data_token_id_pairs
        mm_inputs.image_offsets = []
        if data_token_pairs is None:
            data_token_pairs = [(mm_inputs.im_start_id, mm_inputs.im_end_id)]
        if data_token_pairs is None:
            logger.warning(
                "No data_token_pairs provided; RadixAttention may be affected."
            )
            return input_ids
        start_token_ids = [s for s, _e in data_token_pairs]
        end_tokens_ids = [e for _s, e in data_token_pairs]

        padded_ids = []
        last_idx = 0
        data_idx = -1

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]

        if len(start_indices) != len(end_indices):
            return input_ids

        for start_idx, end_idx in zip(start_indices, end_indices):
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

            if input_ids[start_idx] in start_token_ids:
                data_idx += 1
                mm_inputs.image_offsets += [start_idx]

            if data_idx >= len(mm_inputs.pad_values):
                data_idx = len(mm_inputs.pad_values) - 1

            num_tokens = end_idx - start_idx - 1
            pad_value = pad_values[data_idx]
            padded_ids.extend([pad_value] * num_tokens)

            last_idx = end_idx

        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids), "Length validation failed"
        return padded_ids
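
# Illustrative sketch of the token-pair padding above (comments only; the token
# ids and pad values are assumed for illustration): with
# data_token_pairs=[(100, 101)] and mm_inputs.pad_values=[7777],
#   input_ids  = [5, 100, 11, 12, 13, 101, 6]
#   padded_ids = [5, 100, 7777, 7777, 7777, 101, 6]
# The tokens between the start/end markers are overwritten with the image's pad
# value (an image hash used for prefix matching in RadixAttention), while the
# markers themselves are preserved.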


class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
    """In this pattern, data is represented by a single special token_id (mm_inputs.im_token_id),
    which first needs to be expanded to multiple tokens and then replaced with their padding values.

    This strategy should be used when a single data token represents content that should
    be expanded to multiple tokens during processing.
    """

    def __init__(
        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
    ) -> None:
        self.num_data_token_calc_func = num_data_token_calc_func

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        This function follows this procedure:
            1. each data token is expanded; the number of expanded tokens is calculated by `num_data_token_calc_func`
            2. the expanded data tokens are replaced with their pad_values
        """
        image_grid_thws = mm_inputs.image_grid_thws
        pad_values = mm_inputs.pad_values

        image_indices = [
            idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
        ]

        mm_inputs.image_offsets = []

        input_ids_with_image = []
        for image_cnt, _ in enumerate(image_grid_thws):
            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
            if image_cnt == 0:
                non_image_tokens = input_ids[: image_indices[image_cnt]]
            else:
                non_image_tokens = input_ids[
                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
                ]
            input_ids_with_image.extend(non_image_tokens)
            mm_inputs.image_offsets.append(len(input_ids_with_image))
            pad_ids = pad_values * (
                (num_image_tokens + len(pad_values)) // len(pad_values)
            )
            input_ids_with_image.extend(pad_ids[:num_image_tokens])
        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])

        return input_ids_with_image
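
# Illustrative sketch of the single-token expansion above (comments only; token
# ids and sizes are assumed for illustration): with mm_inputs.im_token_id=151655,
# pad_values=[7777], and num_data_token_calc_func returning 4 for the given
# image_grid_thw,
#   input_ids            = [5, 151655, 6]
#   input_ids_with_image = [5, 7777, 7777, 7777, 7777, 6]
# The single placeholder is dropped, 4 pad tokens are inserted in its place, and
# mm_inputs.image_offsets records the position (here 1) where the expansion starts.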


class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens are represented as repeated image tokens (e.g. <image><image>...<image>)."""

    def __init__(self, image_token_id: torch.Tensor) -> None:
        self.image_token_id = image_token_id

    def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
        """
        This function replaces the image tokens themselves with pad_values accordingly
        """
        pad_values = image_inputs.pad_values
        assert len(pad_values) != 0

        input_ids_tensor = torch.tensor(input_ids)
        mask = torch.isin(input_ids_tensor, self.image_token_id)

        num_image_tokens = mask.sum().item()
        repeated_pad_values = torch.tensor(pad_values).repeat(
            num_image_tokens // len(pad_values) + 1
        )[:num_image_tokens]

        input_ids_tensor[mask] = repeated_pad_values
        return input_ids_tensor.tolist()
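
# Illustrative sketch of the image-token padding above (comments only; token ids
# are assumed for illustration): with image_token_id=torch.tensor([32000]) and
# pad_values=[7777, 8888] (one hash per image),
#   input_ids = [5, 32000, 32000, 32000, 32000, 6]
#   output    = [5, 7777, 8888, 7777, 8888, 6]
# Every placeholder position is overwritten with the pad values, cycled to cover
# all image tokens.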


def embed_mm_inputs(
    mm_input: MultimodalInputs,
    input_ids: torch.Tensor,
    input_embedding: nn.Embedding,
    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
    placeholder_token_ids: Optional[List[int]] = None,
) -> Optional[torch.Tensor]:
    """
    Calculate the image embeddings if necessary, then scatter the result with
    the help of a boolean mask denoting the embed locations

    Returns:
        final embedding: Optional[torch.Tensor]
    """
    if mm_input is None:
        return None

    placeholder_token_ids = placeholder_token_ids or mm_input.pad_values

    # boolean masking the special tokens
    special_image_mask = torch.isin(
        input_ids,
        torch.tensor(placeholder_token_ids, device=input_ids.device),
    ).unsqueeze(-1)

    num_image_tokens_in_input_ids = special_image_mask.sum()

    if num_image_tokens_in_input_ids == 0:
        # unexpected: no placeholder tokens were found in input_ids; fall back to text-only embedding
        inputs_embeds = input_embedding(input_ids)
    else:
        image_embedding = mm_data_embedding_func(mm_input)

        # print(f"image_embedding: {image_embedding.shape}")
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283

        if image_embedding.dim() == 2:
            num_image_tokens_in_embedding = image_embedding.shape[0]
        else:
            num_image_tokens_in_embedding = (
                image_embedding.shape[0] * image_embedding.shape[1]
            )
        if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
            num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
            image_embedding = image_embedding[:num_image, :]
            logger.warning(
                f"Number of images does not match number of special image tokens in the input text. "
                f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
                "tokens from image embeddings."
            )

            # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding.
            # A fix may be to cache the unfinished image embedding for future reuse and to determine the tokens to embed
            # with extend_start_loc and extend_seq_lens.
            if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
                chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
                if chunked_prefill_size != -1:
                    logger.warning(
                        "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
                    )

        vocab_size = input_embedding.num_embeddings
        # Important: clamp after getting the original image regions.
        # Clamp input ids. This is because the input_ids for the image tokens are
        # filled with the hash values of the image for the prefix matching in the radix attention.
        # These values are useless because their embeddings will be replaced by vision embeddings anyway.
        input_ids.clamp_(min=0, max=vocab_size - 1)
        inputs_embeds = input_embedding(input_ids)

        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
            inputs_embeds.device
        )

        inputs_embeds = inputs_embeds.masked_scatter(
            special_image_mask,
            image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
        )
    return inputs_embeds
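
# Rough shape sketch for embed_mm_inputs above (comments only; sizes assumed for
# illustration), for a request with 3 placeholder tokens and hidden size H:
#   input_ids          : [T]        e.g. [5, 7777, 7777, 7777, 6]
#   special_image_mask : [T, 1], expanded to [T, H]
#   image_embedding    : [3, H] (2-dim) or [n_images, tokens_per_image, H] (3-dim)
#   inputs_embeds      : [T, H]; masked_scatter overwrites the rows at the masked
#                        positions with the image embedding rows, in order.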


def embed_image_embedding(
    inputs_embeds: torch.Tensor,
    image_embedding: torch.Tensor,
    image_bounds: torch.Tensor,
) -> torch.Tensor:
    """
    scatter image_embedding into inputs_embeds according to image_bounds
    """
    if len(image_bounds) > 0:
        image_indices = torch.stack(
            [
                torch.arange(start, end, dtype=torch.long)
                for start, end in image_bounds.tolist()
            ]
        ).to(inputs_embeds.device)

        inputs_embeds.scatter_(
            0,
            image_indices.view(-1, 1).repeat(1, inputs_embeds.shape[-1]),
            image_embedding.view(-1, image_embedding.shape[-1]),
        )
    return inputs_embeds
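
# Illustrative sketch for embed_image_embedding above (comments only; values
# assumed for illustration): with
#   image_bounds    = torch.tensor([[2, 5]])  # one (start, end) pair; arange(2, 5) covers positions 2, 3, 4
#   image_embedding = a [3, H] tensor         # one row per covered position
# rows 2, 3 and 4 of inputs_embeds are overwritten in place with the three
# embedding rows via scatter_ along dim 0.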


def general_mm_embed_routine(
    input_ids: torch.Tensor,
    forward_batch: ForwardBatch,
    embed_tokens: nn.Embedding,
    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
    placeholder_token_ids: Optional[List[int]] = None,
):
    """
    a general wrapper function to get final input embeds from multimodal models
    with a language model as causal model
Mick's avatar
Mick committed
290
291
292
293

        Args:
            placeholder_token_ids (List[int]): the ids of mm data placeholder tokens

294
295
    """
    if (
        not forward_batch.forward_mode.is_decode()
        and forward_batch.contains_mm_inputs()
    ):
        mm_input = forward_batch.merge_mm_inputs()
        inputs_embeds = embed_mm_inputs(
            mm_input=mm_input,
            input_ids=input_ids,
            input_embedding=embed_tokens,
            mm_data_embedding_func=mm_data_embedding_func,
            placeholder_token_ids=placeholder_token_ids,
        )
        # once used, mm_inputs is no longer needed; clearing it is just defensive
        forward_batch.mm_inputs = None
    else:
        inputs_embeds = embed_tokens(input_ids)

    return inputs_embeds
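
# Hedged usage sketch for general_mm_embed_routine (comments only): a call site
# inside a multimodal model's forward() might look roughly like the following,
# where `self.get_image_feature` is a model-specific embedding function (the
# name is assumed for illustration):
#   inputs_embeds = general_mm_embed_routine(
#       input_ids=input_ids,
#       forward_batch=forward_batch,
#       embed_tokens=self.model.embed_tokens,
#       mm_data_embedding_func=self.get_image_feature,
#   )
# The resulting inputs_embeds is then fed to the language model in place of input_ids.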


def get_multimodal_data_bounds(
    input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
) -> torch.Tensor:
    """
    Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)

    Returns:
        [bounds_count, 2]
    """
    # All the images in the batch should share the same special image
    # bound token ids.
    start_tokens = [s for s, _e in token_pairs]
    end_tokens = [e for _s, e in token_pairs]

    assert all(isinstance(t, int) for t in start_tokens)
    assert all(isinstance(t, int) for t in end_tokens)

    start_cond = torch.isin(
        input_ids, torch.tensor(start_tokens, device=input_ids.device)
    )
    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))

    (data_start_tokens,) = torch.where(start_cond)
    (data_end_tokens,) = torch.where(end_cond)

    # the im_start_id can sometimes be cached as part of the prefix and stripped from input_ids, but it is needed for the embedding of the images
    if len(data_start_tokens) != len(data_end_tokens):
        if (
            len(data_start_tokens) + 1 == len(data_end_tokens)
            and input_ids[0] in pad_values
            and data_end_tokens[0] < data_start_tokens[0]
        ):
            data_start_tokens = torch.cat(
                [
                    torch.tensor([0], device=data_start_tokens.device),
                    data_start_tokens,
                ]
            )
    valid_image_nums = min(len(data_start_tokens), len(data_end_tokens))

    if valid_image_nums == 0:
        return torch.zeros((0, 2), device=input_ids.device)

    # Filter out pairs where start_token >= end_token
    valid_pairs = []
    for i in range(valid_image_nums):
        start_token = data_start_tokens[i]
        end_token = data_end_tokens[i]
        if start_token < end_token:
            valid_pairs.append((start_token + 1, end_token - 1))

    if not valid_pairs:
        return torch.zeros((0, 2), device=input_ids.device)

    # Convert valid pairs to tensor
    valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
    return valid_pairs_tensor
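
# Illustrative sketch for get_multimodal_data_bounds (comments only; token ids
# assumed for illustration): with token_pairs=[(100, 101)] and
#   input_ids = torch.tensor([5, 100, 7777, 7777, 101, 6])
# the start marker is at index 1 and the end marker at index 4, so the returned
# bounds are [[2, 3]] -- the inclusive positions of the padded data tokens
# between the two markers.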