minicpm.py
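
"""Multimodal processor for MiniCPM-V and MiniCPM-O.

Expands the image/audio placeholder tokens, runs the HuggingFace processor,
and packs the outputs into SGLang MultimodalDataItem objects.
"""
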
from typing import List, Union

import torch
from transformers import BaseImageProcessorFast

from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.minicpmo import MiniCPMO
from sglang.srt.models.minicpmv import MiniCPMV


# Compatible with both MiniCPM-O and MiniCPM-V
class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
    models = [MiniCPMV, MiniCPMO]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.image_token = "(<image>./</image>)"
        self.audio_token = "(<audio>./</audio>)"

    def process_data_task(self, input_text, images=None, audios=None):
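        """Invoke the underlying HF processor on text plus optional media.

        Empty image/audio lists are normalized to None so the processor skips
        the corresponding modality; missing output fields default to None.
        """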

        if isinstance(images, list) and len(images) == 0:
            images = None
        if isinstance(audios, list) and len(audios) == 0:
            audios = None
        processor = self._processor
        args = {}
        # Fast image processors can run preprocessing directly on the GPU.
        if isinstance(processor, BaseImageProcessorFast):
            args["device"] = "cuda"
        result = processor(
            text=input_text,
            images=images,
            audios=audios,
            return_tensors="pt",
            chunk_input=True,
            **args,
        )
        return {
            "input_ids": result.input_ids,
            "pixel_values": getattr(result, "pixel_values", None),
            "tgt_sizes": getattr(result, "tgt_sizes", None),
            "audio_features": getattr(result, "audio_features", None),
            "audio_feature_lens": getattr(result, "audio_feature_lens", None),
            "audio_bounds": getattr(result, "audio_bounds", None),
        }

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
        **kwargs,
    ):
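        """Load multimodal inputs, run the processor, and build mm items.

        Returns None when the request carries no image or audio data;
        otherwise returns the flattened input ids together with the special
        token ids the model uses to locate image and audio spans.
        """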
        audio_data = request_obj.audio_data
        if not image_data and not audio_data:
            return None
        if not isinstance(image_data, list):
            image_data = [image_data]
        if not isinstance(audio_data, list):
            audio_data = [audio_data]

        base_output = self.load_mm_data(
            prompt=input_text,
            max_req_input_len=max_req_input_len,
            audio_data=audio_data,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(
                image_token=self.image_token, audio_token=self.audio_token
            ),
        )
        if base_output is None:
            return None

        res = self.process_mm_data(
            input_text=base_output.input_text,
            images=base_output.images,
            audios=base_output.audios,
        )

        # Collect special token ids. Slice tokens delimit image slices; audio
        # tokens are present only when the tokenizer has audio support
        # (MiniCPM-O).
        tokenizer = self._processor.tokenizer
        slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
            None,
            None,
            None,
            None,
        )
        if tokenizer.slice_start_id:
            slice_start_id = tokenizer.slice_start_id
            slice_end_id = tokenizer.slice_end_id
        if hasattr(tokenizer, "audio_start_id"):
            audio_start_id = tokenizer.audio_start_id
            audio_end_id = tokenizer.audio_end_id

        # MiniCPM marks image positions with the unk token id.
        im_token_id = tokenizer.unk_token_id
        pixel_values = res["pixel_values"]
        tgt_sizes = res["tgt_sizes"]

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
            )

        if not isinstance(tgt_sizes, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
            )

        if len(pixel_values) != len(tgt_sizes):
            raise ValueError(
                "Inconsistent batch lengths, found: "
                f"{len(pixel_values)} vs. {len(tgt_sizes)}"
            )

        # Flatten the nested per-image slice lists into flat lists of slice
        # tensors, keeping pixel values and target sizes aligned.
        pixel_values_flat: List[torch.Tensor] = []
        tgt_sizes_flat: List[torch.Tensor] = []
        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
            # Per image: slice counts must match between the two structures.
            if len(pixel_b) != len(tgt_b):
                raise ValueError(
                    f"Inconsistent N lengths, found: {len(pixel_b)} vs {len(tgt_b)}"
                )
            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
                pixel_values_flat.append(pixel_n)
                tgt_sizes_flat.append(tgt_n)

        pixel_values = pixel_values_flat

        # Pack each non-empty modality into a MultimodalDataItem.
        items = []
        if len(pixel_values) != 0:
            item = MultimodalDataItem(
                pixel_values=pixel_values,
                tgt_size=tgt_sizes_flat,
                modality=Modality.IMAGE,
            )
            items += [item]

        if (
            "audio_features" in res
            and res["audio_features"] is not None
            and len(res["audio_features"]) != 0
        ):
            item = MultimodalDataItem(
                audio_features=[res["audio_features"]],
                audio_feature_lens=res["audio_feature_lens"],
                modality=Modality.AUDIO,
            )
            items += [item]

        return {
            "mm_items": items,
            "input_ids": res["input_ids"].flatten().tolist(),
            "audio_start_id": audio_start_id,
            "audio_end_id": audio_end_id,
            "im_token_id": im_token_id,
            "im_start_id": tokenizer.im_start_id,
            "im_end_id": tokenizer.im_end_id,
            "slice_start_id": slice_start_id,
            "slice_end_id": slice_end_id,
        }
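
# Minimal usage sketch (hypothetical request object and values; in practice
# SGLang's tokenizer manager constructs and drives this processor):
#
#   processor = MiniCPMMultimodalProcessor(hf_config, server_args, hf_processor)
#   mm_inputs = await processor.process_mm_data_async(
#       image_data=[image_bytes],
#       input_text="(<image>./</image>)\nDescribe the image.",
#       request_obj=req,  # must expose .audio_data
#       max_req_input_len=4096,
#   )
#   # mm_inputs["mm_items"] holds the MultimodalDataItem list and
#   # mm_inputs["input_ids"] the flattened token ids.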