"""
Multi-modality utils
"""

import hashlib
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
from torch import nn

from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.managers.schedule_batch import (
    Modality,
    MultimodalDataItem,
    MultimodalInputs,
    global_server_args_dict,
)
from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import flatten_nested_list, print_warning_once
from sglang.utils import logger

# NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
# to ensure consistent logging behavior across the codebase. This prevents issues with log
# propagation that can cause some log messages (like 'server is fired up') to not appear
# in the console when multimodal support is enabled.


class MultiModalityDataPaddingPattern:
    """
    Data tokens (like image tokens) often need special handling during padding
    to maintain model compatibility. This class provides the interface for
    implementing different padding strategies for data tokens.
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Pad the input ids sequence containing data tokens, and replace them with pad_values
        """
        pass


class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)

    The padded values within a region enclosed by a token pair are all the same, namely the MultimodalDataItem's pad_value.

    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
    """

    def __init__(
        self,
        data_token_pairs: Optional[List[Tuple[int, int]]],
        data_start_token_ids: Optional[List[int]] = None,
    ) -> None:
        """

        Args:
            data_start_token_ids marks the start of a single multimodal data
            See Minicpmo's slice_start_id for example
        """
        self.data_token_id_pairs = data_token_pairs
        self.data_start_token_ids = data_start_token_ids or [
            s for s, _e in data_token_pairs
        ]

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Replace the data tokens between each start/end token pair with the corresponding pad_values.
        """
        pad_values = [item.pad_value for item in mm_inputs.mm_items]
        data_token_pairs = self.data_token_id_pairs
        mm_inputs.data_offsets = []
        if data_token_pairs is None:
            data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
        if data_token_pairs is None:
            print_warning_once(
                "No data_token_pairs provided, RadixAttention might be influenced."
            )
            return input_ids
        start_token_ids = [s for s, _e in data_token_pairs]
        end_tokens_ids = [e for _s, e in data_token_pairs]

        padded_ids = []
        last_idx = 0
        data_idx = -1

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]

        if len(start_indices) != len(end_indices):
            return input_ids

        for start_idx, end_idx in zip(start_indices, end_indices):
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

            if input_ids[start_idx] in self.data_start_token_ids:
                data_idx += 1
                mm_inputs.data_offsets += [start_idx]

            if data_idx >= len(pad_values):
                data_idx = len(pad_values) - 1

            num_tokens = end_idx - start_idx - 1
            pad_value = pad_values[data_idx]
            padded_ids.extend([pad_value] * num_tokens)

            last_idx = end_idx

        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids), "Padded sequence length does not match input length"
        return padded_ids
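
# A minimal usage sketch of the token-pair padding pattern (the token ids and pad values
# below are made up for illustration; real values come from the tokenizer/processor):
#
#     pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(1000, 1001)])
#     # mm_inputs.mm_items holds two items with pad_value 77 and 88
#     pattern.pad_input_tokens(
#         [5, 1000, 11, 12, 1001, 6, 1000, 13, 14, 1001, 7], mm_inputs
#     )
#     # -> [5, 1000, 77, 77, 1001, 6, 1000, 88, 88, 1001, 7]
#
# Each region between a start/end pair is overwritten with that item's pad_value, while the
# start/end tokens themselves and the surrounding text tokens are left untouched.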


class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens should be represented as repetitions of a single token
    e.g. <image><image>....<image>, or <audio><audio>...<audio>
    """

    def __init__(self, token_ids: List[int]) -> None:
        self.token_ids = token_ids

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
        and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
        """
        pad_values = [item.pad_value for item in mm_inputs.mm_items]
        if not pad_values:
            # No multimodal items, return original input_ids
            return input_ids
        if not input_ids:
            return []

        input_ids_tensor = torch.tensor(input_ids)
        device = input_ids_tensor.device
        token_ids_tensor = torch.tensor(self.token_ids, device=device)
        mask = torch.isin(input_ids_tensor, token_ids_tensor)

        if not mask.any():
            # No tokens match token_ids, return original input_ids
            return input_ids

        # Find contiguous regions
        padded_mask = torch.cat(
            (
                torch.tensor([False], device=device),
                mask,
                torch.tensor([False], device=device),
            )
        )
        # Find indices where the mask value changes
        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]

        # Start indices are where False changes to True
        starts = diff_indices[::2]
        # End indices are where True changes to False (exclusive index)
        ends = diff_indices[1::2]

        # Check if the number of regions matches the number of pad values
        if len(starts) != len(pad_values):
            # Mismatch between number of regions and pad values: cycle pad_values so each region gets one
            num_regions = len(starts)
            num_pad_values = len(pad_values)
            if num_regions > 0 and num_pad_values > 0:
                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
                    :num_regions
                ]
            else:  # If no regions or no pad_values, this loop won't run anyway.
                pad_values = []  # Ensure pad_values is empty if starts is empty

        # Create a copy to modify
        output_ids_tensor = input_ids_tensor.clone()

        # Replace tokens in each region with the corresponding pad value
        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
        for i in range(min(len(starts), len(pad_values))):
            start_idx = starts[i]
            end_idx = ends[i]
            pad_value = pad_values[i]
            if pad_value is not None:  # Ensure pad_value is not None before assignment
                output_ids_tensor[start_idx:end_idx] = pad_value
            else:
                logger.warning(f"Skipping region {i} due to None pad_value.")
        return output_ids_tensor.tolist()
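
# A minimal usage sketch of the repeated-token padding pattern (token id 151655 and the pad
# values are made up for illustration):
#
#     pattern = MultiModalityDataPaddingPatternMultimodalTokens(token_ids=[151655])
#     # mm_inputs.mm_items holds two items with pad_value 1111 and 2222
#     pattern.pad_input_tokens(
#         [1, 151655, 151655, 151655, 2, 151655, 151655, 3], mm_inputs
#     )
#     # -> [1, 1111, 1111, 1111, 2, 2222, 2222, 3]
#
# Each contiguous run of placeholder tokens forms one region and is filled with the pad_value
# of the corresponding MultimodalDataItem.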


embedding_cache = None


def init_embedding_cache(max_size: int):
    global embedding_cache
    embedding_cache = MultiModalCache(max_size)


def get_embedding_hash(embedding_items: List[MultimodalDataItem]) -> int:
    hash_list = [item.hash for item in embedding_items]
    return hash(tuple(hash_list))


def get_embedding_chunk(
    embedding: torch.Tensor,
    extend_prefix_len: int,
    extend_seq_len: int,
    items_offset: List[Tuple[int, int]],
) -> Tuple[torch.Tensor, int, int]:
    """
    Extract a chunk of embeddings based on the specified prefix length, sequence length, and offset ranges.

    Args:
        embedding: The full embedding tensor to extract a chunk from
        extend_prefix_len: The starting position (prefix length) for extraction
        extend_seq_len: The number of tokens to extract
        items_offset: List of [start, end] offset ranges for multimodal items in the input sequence

    Returns:
        A tuple containing:
        - The extracted embedding chunk as a tensor
        - The start index used for extraction
        - The end index used for extraction

    Note:
        If there's no overlap between the requested range and the offset ranges,
        an empty tensor is returned with zeros for start and end indices.
    """
    start_index, end_index = 0, 0
    extend_start_index = extend_prefix_len
    extend_end_index = extend_prefix_len + extend_seq_len - 1

    for start, end in items_offset:
        if extend_start_index >= start and extend_start_index <= end:
            start_index += extend_start_index - start
        elif extend_start_index > end:
            start_index += end - start + 1

        if extend_end_index >= start and extend_end_index <= end:
            end_index += extend_end_index - start + 1
        elif extend_end_index > end:
            end_index += end - start + 1
    # some models' embeddings are 3-dim; reshape to 2-dim
    embedding = embedding.reshape(-1, embedding.shape[-1])
    embedding_chunk = embedding[start_index:end_index]
    return embedding_chunk, start_index, end_index
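
# A worked example of the index arithmetic above (all numbers are illustrative): suppose a
# request contains one image whose 12 embedding rows occupy sequence positions 4..15
# (items_offset=[(4, 15)]), and the current prefill chunk covers extend_prefix_len=8,
# extend_seq_len=6, i.e. positions 8..13. Then:
#   - extend_start_index=8 falls inside (4, 15)  -> start_index = 8 - 4 = 4
#   - extend_end_index=13 falls inside (4, 15)   -> end_index   = 13 - 4 + 1 = 10
# so the returned chunk is embedding[4:10], the 6 rows belonging to this chunk.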


def _get_precomputed_embedding(
    items: List[MultimodalDataItem],
) -> Optional[torch.Tensor]:
    """
    If all items have precomputed_features, return their concatenation.
    If some but not all have precomputed_features, raise NotImplementedError.
    If none have precomputed_features, return None.
    """
    precomputed_features = [item.precomputed_features for item in items]
    if any(feature is not None for feature in precomputed_features):
        if not all(feature is not None for feature in precomputed_features):
            raise NotImplementedError(
                "MM inputs where only some items have precomputed features are not supported."
            )
        result = torch.concat(precomputed_features)
        # some models' embeddings are 3-dim; reshape to 2-dim (similar to get_embedding_chunk)
        result = result.reshape(-1, result.shape[-1])
        return result
    return None
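
# For example (illustrative shapes): two items whose precomputed_features have shapes (64, H)
# and (32, H) are concatenated into a single (96, H) tensor, and a 3-dim feature such as
# (1, 64, H) is flattened to (64, H) by the reshape above.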


def _get_chunked_prefill_embedding(
    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
    embedding_items: List[MultimodalDataItem],
    items_size: List[int],
    prefix_length: List[int],
    extend_length: List[int],
    items_offset_list: List[List[Tuple[int, int]]],
) -> Optional[torch.Tensor]:
    # Calculate the embedding for each request, trying the cache first to avoid recomputation
    embedding_list = []
    for i in range(len(items_size) - 1):
        if items_size[i] == items_size[i + 1]:
            continue
        embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
        items_offset = items_offset_list[i]
        embedding_items_hash = get_embedding_hash(embedding_items_per_req)
        # if all items lie entirely within the already-computed prefix, skip the embedding calculation
        if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
            continue
        embedding_per_req = embedding_cache.get(embedding_items_hash)
        if embedding_per_req is None:
            embedding_per_req = data_embedding_func(embedding_items_per_req)
            if not embedding_cache.put(embedding_items_hash, embedding_per_req):
                print_warning_once(
                    "Multimodal embedding cache is full. Consider increasing the "
                    "`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
                )

        embedding_per_req_chunk, _, end_index = get_embedding_chunk(
            embedding=embedding_per_req,
            extend_prefix_len=prefix_length[i],
            extend_seq_len=extend_length[i],
            items_offset=items_offset,
        )
        # remove this item from the cache once the chunk reaches the end of the embedding
        embedding_per_req_length = (
            embedding_per_req.shape[0]
            if embedding_per_req.dim() == 2
            else embedding_per_req.shape[0] * embedding_per_req.shape[1]
        )
        if end_index == embedding_per_req_length:
            embedding_cache.free(embedding_items_hash)
        embedding_list.append(embedding_per_req_chunk)
    if len(embedding_list) == 0:
        return None
    return torch.concat(embedding_list, dim=0)
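
# A walk-through of the chunked-prefill flow above (numbers are illustrative): consider one
# request whose single image produces a 100-row embedding at sequence positions 10..109
# (items_size=[0, 1], items_offset_list=[[(10, 109)]]), prefilled in two chunks.
#   - Chunk 1: prefix_length=[0], extend_length=[70]. The embedding is computed, stored in
#     embedding_cache, and rows [0:60] (positions 10..69) are returned.
#   - Chunk 2: prefix_length=[70], extend_length=[40]. The cached embedding is reused, rows
#     [60:100] are returned, and since end_index == 100 the cache entry is freed.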


def _get_multimodal_mask(
    input_ids: torch.Tensor, placeholder_tensor: torch.Tensor
) -> torch.Tensor:
    return torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)
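
# For example, _get_multimodal_mask(torch.tensor([5, 77, 77, 6]), torch.tensor([77])) returns
# a (4, 1) boolean mask [[False], [True], [True], [False]]; the trailing dimension lets the
# mask be expanded over the hidden dimension later.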


def _adjust_embedding_length(
    embedding: torch.Tensor,
    mask: torch.Tensor,
    logger,
) -> torch.Tensor:
    num_mm_tokens_in_embedding = embedding.shape[0]
    num_mm_tokens_in_input_ids = mask.sum().item()
    if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
        logger.warning(
            f"Number of tokens in multimodal embedding does not match those in the input text. "
            f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
            f"tokens from multimodal embeddings."
        )
        if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
            if chunked_prefill_size != -1:
                logger.warning(
                    "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
                )
            # extract from the end: this is a compromise
            if embedding.dim() == 2:
                embedding = embedding[-num_mm_tokens_in_input_ids:, :]
            else:
                num_multimodal = num_mm_tokens_in_input_ids // embedding.shape[0]
                embedding = embedding[-num_multimodal:, :]
        else:
            raise RuntimeError(
                f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
            )
    return embedding


def get_embedding_and_mask(
    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
    embedding_items: List[MultimodalDataItem],
    placeholder_tensor: torch.Tensor,
    input_ids: torch.Tensor,
    items_size: List[int],
    prefix_length: List[int],
    extend_length: List[int],
    items_offset_list: List[List[Tuple[int, int]]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.

    Args:
        data_embedding_func: Function that generates embeddings for multimodal items
        embedding_items: List of multimodal items to embed
        placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
        input_ids: The input token IDs tensor
        items_size: Cumulative sizes of multimodal items per request
        prefix_length: Prefix lengths for each request
        extend_length: Sequence lengths for each request
        items_offset_list: List of offset ranges for multimodal items in each request

    Returns:
        A tuple containing:
        - The generated embeddings tensor
        - A boolean mask tensor indicating where these embeddings should be placed
    """
    # 1. Get embedding
    embedding = _get_precomputed_embedding(embedding_items)
    if embedding is None:
        embedding = _get_chunked_prefill_embedding(
            data_embedding_func,
            embedding_items,
            items_size,
            prefix_length,
            extend_length,
            items_offset_list,
        )
        if embedding is None:
            return None, None
    # 2. Get mask
    special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
    # 3. Adjust embedding length if needed
    embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
    return embedding, special_multimodal_mask


def embed_mm_inputs(
    mm_inputs_list: List[MultimodalInputs],
    extend_prefix_lens: List[int],
    extend_seq_lens: List[int],
    input_ids: torch.Tensor,
    input_embedding: nn.Embedding,
    image_data_embedding_func: Callable[
        [List[MultimodalDataItem]], torch.Tensor
    ] = None,
    audio_data_embedding_func: Callable[
        [List[MultimodalDataItem]], torch.Tensor
    ] = None,
    placeholder_tokens: dict[Modality, List[int]] = None,
) -> Optional[torch.Tensor]:
    """
    Embed multimodal inputs and integrate them with text token embeddings.

    Args:
        mm_inputs_list: List of multimodal inputs to process
        extend_prefix_lens: Prefix lengths for each request
        extend_seq_lens: Sequence lengths for each request
        input_ids: Input token IDs tensor
        input_embedding: Embedding layer for text tokens
        image_data_embedding_func: Function to embed image data
        audio_data_embedding_func: Function to embed audio data
        placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)

    Returns:
        Combined embedding tensor with multimodal content integrated
    """

    if mm_inputs_list is None:
        return None

    # 1. Collect the multimodal data items that appear in input_ids, with the help of pad_values.
    # We assume multimodal data are represented by their pad_values in input_ids.
    item_flatten_list = []
    for mm_inputs in mm_inputs_list:
        item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

    embeddings, masks = [], []

    # 2. Get multimodal embedding separately
    # TODO: make this more generic
    # Try get image embedding if any
    if (
        any(True for item in item_flatten_list if item.is_image())
        and image_data_embedding_func
    ):
        items = [item for item in item_flatten_list if item.is_image()]
        placeholder_tensor = torch.tensor(
            [item.pad_value for item in items],
            device=input_ids.device,
        )
        # calculate per request items length offset
        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
        items_offsets = []
        for i, mm_inputs in enumerate(mm_inputs_list):
            image_items = [item for item in mm_inputs.mm_items if item.is_image()]
            items_size[i + 1] = len(image_items)
            items_offsets.append(
                flatten_nested_list(
                    [
                        item.image_offsets
                        for item in mm_inputs.mm_items
                        if item.is_image()
                    ]
                )
            )
        items_size = torch.cumsum(items_size, dim=0).tolist()

        embedding, mask = get_embedding_and_mask(
            data_embedding_func=image_data_embedding_func,
            embedding_items=items,
            placeholder_tensor=placeholder_tensor,
            input_ids=input_ids,
            items_size=items_size,
            prefix_length=extend_prefix_lens,
            extend_length=extend_seq_lens,
            items_offset_list=items_offsets,
        )
        embeddings += [embedding]
        masks += [mask]

    # Try get audio embedding if any
    if (
        any(True for item in item_flatten_list if item.is_audio())
        and audio_data_embedding_func
    ):
        items = [item for item in item_flatten_list if item.is_audio()]
        placeholder_tensor = torch.tensor(
            [item.pad_value for item in items],
            device=input_ids.device,
        )
        items_offsets = []
        # calculate per request items length offset
        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
        for i, mm_inputs in enumerate(mm_inputs_list):
            audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
            items_size[i + 1] = len(audio_items)
            items_offsets.append(
                flatten_nested_list(
                    [
                        item.audio_offsets
                        for item in mm_inputs.mm_items
                        if item.is_audio()
                    ]
                )
            )
        items_size = torch.cumsum(items_size, dim=0)

        embedding, mask = get_embedding_and_mask(
            data_embedding_func=audio_data_embedding_func,
            embedding_items=items,
            placeholder_tensor=placeholder_tensor,
            input_ids=input_ids,
            items_size=items_size,
            prefix_length=extend_prefix_lens,
            extend_length=extend_seq_lens,
            items_offset_list=items_offsets,
        )
        embeddings += [embedding]
        masks += [mask]

    # 3. Get input embeddings
    vocab_size = input_embedding.num_embeddings
    # Important: clamp only after the original multimodal regions have been located.
    # The input_ids of multimodal tokens are filled with hashes of the multimodal data,
    # used for prefix matching in radix attention. Those values fall outside the vocabulary
    # and are useless here, since their embeddings are replaced by multimodal embeddings anyway.
    input_ids.clamp_(min=0, max=vocab_size - 1)
    inputs_embeds = input_embedding(input_ids)

    # 4. scatter embeddings into input embedding
    for embedding, mask in zip(embeddings, masks):
        if embedding is None or mask is None:
            continue
        mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
        inputs_embeds = inputs_embeds.masked_scatter(
            mask,
            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
        )
    return inputs_embeds
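
# How the scatter step works (illustrative): if input_ids is [5, p, p, 6] where p is an item's
# pad_value, the mask marks rows 1 and 2 of inputs_embeds. masked_scatter then fills the masked
# rows, in order, with consecutive rows of the multimodal embedding, so inputs_embeds[1] becomes
# embedding[0] and inputs_embeds[2] becomes embedding[1], while the text rows keep their token
# embeddings.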


def general_mm_embed_routine(
    input_ids: torch.Tensor,
    forward_batch: ForwardBatch,
    language_model: nn.Module,
    image_data_embedding_func: Optional[
        Callable[[List[MultimodalDataItem]], torch.Tensor]
    ] = None,
    audio_data_embedding_func: Optional[
        Callable[[List[MultimodalDataItem]], torch.Tensor]
    ] = None,
    placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Process multimodal inputs and forward through language model.

    Args:
        input_ids: Input token IDs tensor
        forward_batch: Batch information for model forward pass
        language_model: Base language model to use
        image_data_embedding_func: Function to embed image data
        audio_data_embedding_func: Function to embed audio data
        placeholder_tokens: Token IDs for multimodal placeholders
        **kwargs: Additional arguments passed to language model

    Returns:
        Hidden states from language model forward pass
    """
    assert hasattr(language_model, "get_input_embeddings")
    embed_tokens = language_model.get_input_embeddings()
    if (
        not forward_batch.forward_mode.is_decode()
        and forward_batch.contains_mm_inputs()
    ):
        mm_inputs_list = [
            mm_input for mm_input in forward_batch.mm_inputs if mm_input is not None
        ]
        extend_prefix_lens = [
            prefix_len
            for i, prefix_len in enumerate(forward_batch.extend_prefix_lens_cpu)
            if forward_batch.mm_inputs[i] is not None
        ]
        extend_seq_lens = [
            seq_len
            for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
            if forward_batch.mm_inputs[i] is not None
        ]
        inputs_embeds = embed_mm_inputs(
            mm_inputs_list=mm_inputs_list,
            extend_prefix_lens=extend_prefix_lens,
            extend_seq_lens=extend_seq_lens,
            input_ids=input_ids,
            input_embedding=embed_tokens,
            image_data_embedding_func=image_data_embedding_func,
            audio_data_embedding_func=audio_data_embedding_func,
            placeholder_tokens=placeholder_tokens,
        )
        # Once used, mm_inputs is no longer needed (chunked prefill is disabled for
        # multimodal models); clearing it here is just being defensive.
        forward_batch.mm_inputs = None
    else:
        inputs_embeds = embed_tokens(input_ids)

    hidden_states = language_model(
        input_ids=None,
        forward_batch=forward_batch,
        input_embeds=inputs_embeds,
        **kwargs,
    )
    return hidden_states
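
# A minimal usage sketch (hypothetical model code, not part of this module): a multimodal
# model's forward() typically delegates to this routine, passing its vision encoder as the
# image embedding hook. `self.get_image_feature` and `positions` are assumed names here.
#
#     def forward(self, input_ids, positions, forward_batch):
#         return general_mm_embed_routine(
#             input_ids=input_ids,
#             forward_batch=forward_batch,
#             language_model=self.language_model,
#             image_data_embedding_func=self.get_image_feature,
#             positions=positions,  # forwarded to the language model via **kwargs
#         )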


def get_multimodal_data_bounds(
    input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
) -> torch.Tensor:
    """
    Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)

    Returns:
        [bounds_count, 2]
    """
    # All the multimodal data in the batch should share the same special bound token ids.
    start_tokens = [s for s, _e in token_pairs]
    end_tokens = [e for _s, e in token_pairs]

    assert all(isinstance(t, int) for t in start_tokens)
    assert all(isinstance(t, int) for t in end_tokens)

    start_cond = torch.isin(
        input_ids, torch.tensor(start_tokens, device=input_ids.device)
    )
    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))

    (data_start_tokens,) = torch.where(start_cond)
    (data_end_tokens,) = torch.where(end_cond)

    # The im_start_id can sometimes be cached as part of the prefix, but it is still needed for embedding the multimodal data.
    if len(data_start_tokens) != len(data_end_tokens):
        if (
            len(data_start_tokens) + 1 == len(data_end_tokens)
            and input_ids[0] in pad_values
            and data_end_tokens[0] < data_start_tokens[0]
        ):
            data_start_tokens = torch.cat(
                [
                    torch.tensor([0], device=data_start_tokens.device),
                    data_start_tokens,
                ]
            )
    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))

    if valid_mm_data_nums == 0:
        return torch.zeros((0, 2), device=input_ids.device)

    # Filter out pairs where start_token >= end_token
    valid_pairs = []
    for i in range(valid_mm_data_nums):
        start_token = data_start_tokens[i]
        end_token = data_end_tokens[i]
        if start_token < end_token:
            valid_pairs.append((start_token + 1, end_token - 1))

    if not valid_pairs:
        return torch.zeros((0, 2), device=input_ids.device)

    # Convert valid pairs to tensor
    valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
    return valid_pairs_tensor
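
# For example, with token_pairs=[(1000, 1001)] and
# input_ids = tensor([5, 1000, 77, 77, 1001, 6]) (pad value 77), the start/end tokens sit at
# indices 1 and 4, so the returned bounds are [[2, 3]]: the inclusive index range of the
# padded data tokens between the pair.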


def data_hash(data) -> int:
    hash_bytes = hashlib.sha256(data).digest()[:8]
    return int.from_bytes(hash_bytes, byteorder="big", signed=False)


def tensor_hash(tensor_list) -> int:
    """
    Hash a tensor or a list of tensors.
    """
    tensor = tensor_list
    if isinstance(tensor_list, list):
        tensor_list = flatten_nested_list(tensor_list)
        tensor_list = [
            x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
        ]
        tensor = torch.concat(tensor_list)
    if tensor.is_cuda:
        return gpu_tensor_hash(tensor)
    tensor = tensor.detach().contiguous()

    if tensor.dtype == torch.bfloat16:
        # memoryview() doesn't support PyTorch's BFloat16 dtype
        tensor = tensor.float()

    assert isinstance(tensor, torch.Tensor)
    if tensor.is_cuda:
        # TODO: improve this
        tensor_cpu = tensor.cpu()
    else:
        tensor_cpu = tensor

    mv = memoryview(tensor_cpu.numpy())
    return data_hash(mv.tobytes())


def hash_feature(f):
    if isinstance(f, list):
        if isinstance(f[0], torch.Tensor):
            return tensor_hash(f)
        return data_hash(tuple(flatten_nested_list(f)))
    elif isinstance(f, np.ndarray):
        arr = np.ascontiguousarray(f)
        arr_bytes = arr.tobytes()
        return data_hash(arr_bytes)
    elif isinstance(f, torch.Tensor):
        return tensor_hash([f])
    return data_hash(f)
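
# Illustrative calls (the returned values depend on the underlying bytes):
#     hash_feature(np.zeros((2, 3), dtype=np.float32))    # hashes the raw array bytes
#     hash_feature(torch.ones(4))                         # routed through tensor_hash
#     hash_feature([torch.ones(2), torch.zeros(2)])       # tensors are flattened and concatenated
# Each returns a 64-bit unsigned int taken from the first 8 bytes of a SHA-256 digest
# (or from gpu_tensor_hash for CUDA tensors).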