from typing import List, Union

import torch

from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.minicpmo import MiniCPMO
from sglang.srt.models.minicpmv import MiniCPMV
from sglang.srt.multimodal.processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)


# Compatible with both MiniCPM-O and MiniCPM-V
class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
    models = [MiniCPMV, MiniCPMO]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
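        # Literal placeholder strings that MiniCPM chat prompts use to mark
        # media positions in the text.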
        self.image_token = "(<image>./</image>)"
        self.audio_token = "(<audio>./</audio>)"
        self.video_token = "(<video>./</video>)"

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        audio_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
        **kwargs,
    ):
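        # Load the raw media and expand the multimodal placeholder tokens in
        # the prompt.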
        base_output = self.load_mm_data(
            prompt=input_text,
            max_req_input_len=max_req_input_len,
            audio_data=audio_data,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(
                image_token=self.image_token,
                video_token=self.video_token,
                audio_token=self.audio_token,
            ),
        )
        if base_output is None:
            return None

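        # Run the wrapped HF processor over the expanded prompt and loaded
        # media.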
        res = self.process_mm_data(
            input_text=base_output.input_text,
            images=base_output.images,
            audios=base_output.audios,
        )

        # Collect special token ids
        tokenizer = self._processor.tokenizer
        slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
            None,
            None,
            None,
            None,
        )
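        # Slice ids delimit tiled sub-image regions; audio ids exist only on
        # the MiniCPM-O tokenizer, hence the hasattr guard.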
        if tokenizer.slice_start_id:
            slice_start_id = tokenizer.slice_start_id
            slice_end_id = tokenizer.slice_end_id
        if hasattr(tokenizer, "audio_start_id"):
            audio_start_id = tokenizer.audio_start_id
            audio_end_id = tokenizer.audio_end_id

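        # MiniCPM reuses the tokenizer's unk token id as the image
        # placeholder id.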
        im_start_id = tokenizer.im_start_id
        im_end_id = tokenizer.im_end_id
        im_token_id = tokenizer.unk_id
        pixel_values = res["pixel_values"]
        tgt_sizes = res["tgt_sizes"]

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
            )

        if not isinstance(tgt_sizes, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
            )

        if len(pixel_values) != len(tgt_sizes):
            raise ValueError(
                "Inconsistent batch lengths, found: "
                f"{len(pixel_values)} vs. {len(tgt_sizes)}"
            )

        pixel_values_flat: List[torch.Tensor] = []
        tgt_sizes_flat: List[torch.Tensor] = []
        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
            # each batch entry holds one tensor per image slice; lengths must agree
            if len(pixel_b) != len(tgt_b):
                raise ValueError(
                    "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
                )
            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
                pixel_values_flat.append(pixel_n)
                tgt_sizes_flat.append(tgt_n)

        pixel_values = pixel_values_flat

        items = []
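        # Each media span is delimited by its start/end special ids; image
        # and slice spans are merged into one sorted offset list.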
        input_ids = res["input_ids"].flatten()
        image_offsets = self.get_mm_items_offset_by_pair(
            input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
        )
        if slice_start_id is not None:
            slice_offsets = self.get_mm_items_offset_by_pair(
                input_ids=input_ids,
                mm_start_id=slice_start_id,
                mm_end_id=slice_end_id,
            )
            image_offsets.extend(slice_offsets)
        image_offsets = sorted(image_offsets)

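        # Bundle the flattened image slices into a single IMAGE item.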
        if len(pixel_values) != 0:
            item = MultimodalDataItem(
                pixel_values=pixel_values,
                offsets=image_offsets,
                tgt_size=tgt_sizes_flat,
                modality=Modality.IMAGE,
            )
            items.append(item)

        if (
            "audio_features" in res
            and res["audio_features"] is not None
            and len(res["audio_features"]) != 0
        ):
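            # Audio offsets can only be recovered when the tokenizer defines
            # the audio boundary ids (MiniCPM-O); otherwise they are left unset.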
            if audio_start_id is not None and audio_end_id is not None:
                audio_offsets = self.get_mm_items_offset_by_pair(
                    input_ids=input_ids,
                    mm_start_id=audio_start_id,
                    mm_end_id=audio_end_id,
                )
            else:
                audio_offsets = None
            item = MultimodalDataItem(
                audio_features=[res["audio_features"]],
                audio_feature_lens=res["audio_feature_lens"],
                offsets=audio_offsets,
                modality=Modality.AUDIO,
            )
            items.append(item)
Mick's avatar
Mick committed
144
        return {
            "mm_items": items,
            "input_ids": input_ids.tolist(),
            "audio_start_id": audio_start_id,
            "audio_end_id": audio_end_id,
            "im_token_id": im_token_id,
            "im_start_id": im_start_id,
            "im_end_id": im_end_id,
            "slice_start_id": slice_start_id,
            "slice_end_id": slice_end_id,
        }