# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py

from functools import lru_cache

import numpy as np
import torch
import torchvision.transforms as T
from decord import VideoReader, cpu, gpu
from PIL import Image
from torchvision.transforms import InterpolationMode

from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
from sglang.srt.models.internvl import InternVLChatModel
from sglang.srt.multimodal.processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)


class InternVLImageProcessor(BaseMultimodalProcessor):
    models = [InternVLChatModel, InternS1ForConditionalGeneration]

    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    @staticmethod
    @lru_cache(maxsize=1)
    def _get_normalize_tensors(device="cuda", dtype=torch.float32):
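        """Build and cache the ImageNet mean/std as (3, 1, 1) tensors on `device`.

        Cached with lru_cache so repeated per-frame normalization reuses the same
        device tensors instead of re-allocating them on every call.
        """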
        mean = torch.tensor(
            InternVLImageProcessor.IMAGENET_MEAN, device=device, dtype=dtype
        ).view(-1, 1, 1)
        std = torch.tensor(
            InternVLImageProcessor.IMAGENET_STD, device=device, dtype=dtype
        ).view(-1, 1, 1)
        return mean, std

    def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
        super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
        image_size = (
            getattr(hf_config, "force_image_size", None)
            or hf_config.vision_config.image_size
        )
        patch_size = hf_config.vision_config.patch_size
        if isinstance(image_size, list):
            image_size = image_size[0]
        if isinstance(patch_size, list):
            patch_size = patch_size[0]

        self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
        self.IMG_START_TOKEN = "<img>"
        self.IMG_END_TOKEN = "</img>"
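        # Number of <IMG_CONTEXT> tokens emitted per image tile:
        # (image_size / patch_size)^2 scaled by downsample_ratio^2.
        # E.g. a 448 px input with 14 px patches and a 0.5 downsample ratio
        # gives (448 / 14)^2 * 0.25 = 256 tokens per tile.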
        self.num_image_token = int(
            (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
        )
        if hasattr(self._processor, "tokenizer"):
            tokenizer = self._processor.tokenizer
        else:
            tokenizer = self._processor
        self.tokenizer = tokenizer

        self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
        self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
        self.mm_tokens = MultimodalSpecialTokens(
            image_token="<IMG_CONTEXT>",
            image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN),
        ).build(_image_processor)

    @staticmethod
    def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
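        """Pick `num_segments` frame indices spread evenly over the clip.

        Each index is roughly the midpoint of one of `num_segments` equal spans
        between `bound` (start/end in seconds), clamped to the video length.
        E.g. a 300-frame clip at 30 fps with num_segments=4 samples roughly
        frames [37, 112, 187, 261].
        """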
        if bound:
            start, end = bound[0], bound[1]
        else:
            start, end = -100000, 100000
        start_idx = max(first_idx, round(start * fps))
        end_idx = min(round(end * fps), max_frame)
        seg_size = float(end_idx - start_idx) / num_segments
        frame_indices = np.array(
            [
                int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
                for idx in range(num_segments)
            ]
        )
        return frame_indices

    @staticmethod
    def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
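        """Decode, normalize and tile `num_segments` frames of a video.

        GPU decoding via decord is tried first, with a CPU fallback. Each sampled
        frame is normalized with ImageNet statistics and split into
        `input_size` x `input_size` tiles by `dynamic_preprocess`; the tiles of
        all frames are concatenated along dim 0.

        Returns:
            pixel_values: (sum(num_patches_list), 3, input_size, input_size)
                bfloat16 CUDA tensor of tiles.
            num_patches_list: number of tiles contributed by each frame.
        """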
        try:
            vr = VideoReader(video_path, ctx=gpu(0), num_threads=1)
            use_gpu = True
        except (RuntimeError, OSError) as e:
            print(
                f"[WARNING] Load video on gpu decoding failed: {e}. Falling back to CPU."
            )
            vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
            use_gpu = False

        max_frame = len(vr) - 1
        fps = float(vr.get_avg_fps())

        pixel_values_list = []
        num_patches_list = []
        frame_indices = InternVLImageProcessor.get_index(
            bound, fps, max_frame, first_idx=0, num_segments=num_segments
        )

        mean, std = InternVLImageProcessor._get_normalize_tensors(device="cuda")

        for frame_index in frame_indices:
            # Load frame
            frame = vr[frame_index]
            if use_gpu:
                img = frame.cuda().permute(2, 0, 1).float() / 255.0
            else:
                img_np = frame.asnumpy()
                img = torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0

            img = (img - mean) / std

            tiles = InternVLImageProcessor.dynamic_preprocess(
                img, image_size=input_size, max_num=max_num, use_thumbnail=True
            )

            pixel_values_list.append(tiles)
            num_patches_list.append(tiles.shape[0])

        pixel_values = torch.cat(pixel_values_list, dim=0)
        return pixel_values, num_patches_list

    @staticmethod
    def dynamic_preprocess(tensor, image_size=448, max_num=12, use_thumbnail=False):
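        """Split a normalized CHW image tensor into aspect-ratio-aware tiles.

        Picks the grid (cols, rows) with cols * rows <= max_num whose aspect
        ratio is closest to the image, resizes the image to that grid, slices it
        into image_size x image_size tiles and optionally appends a global
        thumbnail. E.g. a 1200 x 800 (W x H) image with max_num=12 maps to a
        3 x 2 grid: 6 tiles plus 1 thumbnail. Returns a bfloat16 stack of tiles.
        """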
        C, H, W = tensor.shape
        aspect_ratio = W / H

        # Generate all possible aspect ratios
        target_ratios = set(
            (i, j)
            for n in range(1, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if i * j <= max_num
        )
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        # Find closest ratio
        best_ratio_diff = float("inf")
        best_ratio = (1, 1)

        for x, y in target_ratios:
            target_ar = x / y
            diff = abs(aspect_ratio - target_ar)
            blocks = x * y
            best_blocks = best_ratio[0] * best_ratio[1]

            if diff < best_ratio_diff:
                best_ratio_diff = diff
                best_ratio = (x, y)
            elif diff == best_ratio_diff and blocks > best_blocks:
                best_ratio = (x, y)

        target_w, target_h = image_size * best_ratio[0], image_size * best_ratio[1]
        blocks = best_ratio[0] * best_ratio[1]

        # Resize on GPU
        resized = torch.nn.functional.interpolate(
            tensor.unsqueeze(0),
            size=(target_h, target_w),
            mode="bicubic",
            align_corners=False,
        ).squeeze(0)

        # Split into tiles
        tiles = []
        for i in range(blocks):
            x = (i % best_ratio[0]) * image_size
            y = (i // best_ratio[0]) * image_size
            tile = resized[:, y : y + image_size, x : x + image_size]
            tiles.append(tile)

        # Add thumbnail if needed
        if use_thumbnail and len(tiles) > 1:
            thumb = torch.nn.functional.interpolate(
                tensor.unsqueeze(0),
                size=(image_size, image_size),
                mode="bicubic",
                align_corners=False,
            ).squeeze(0)
            tiles.append(thumb)

        return torch.stack(tiles).to(torch.bfloat16)

    async def process_mm_data_async(
        self, image_data, input_text, request_obj, **kwargs
    ):
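        """Tile the images, expand the image tokens in the prompt and tokenize.

        Every <IMG_CONTEXT> occurrence in the prompt is expanded into
        <img> + num_image_token * num_patches context tokens + </img> for the
        corresponding image, the expanded prompt is tokenized, and the tile
        tensors are wrapped into MultimodalDataItem entries.
        """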
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=self.mm_tokens,
            discard_alpha_channel=True,
        )

        num_patches_list = []
        pixel_values = []

        mean, std = InternVLImageProcessor._get_normalize_tensors(device="cuda")

        # Process each image input
        for image_index, image in enumerate(base_output.images):
            try:
                # TODO: video input
                # Convert PIL to GPU tensor
                if isinstance(image, Image.Image):
                    img_np = np.array(image.convert("RGB"))
                    tensor = (
                        torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
                    )
                else:
                    tensor = image.cuda()  # assume already tensor

                tensor = (tensor - mean) / std
                tiles = self.dynamic_preprocess(
                    tensor, image_size=448, max_num=12, use_thumbnail=True
                )

                pixel_values.append(tiles)
                num_patches_list.append(tiles.shape[0])

            except Exception as e:
                print(f"[Error] Failed to process image {image_index}: {e}")
                return None

        # Concatenate the tiles from all images
        pixel_values = torch.cat(pixel_values, dim=0)

        # Swap every <IMG_CONTEXT> in the prompt for a private placeholder so the
        # context tokens inserted below are not themselves expanded again.
        original_placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
        input_text = input_text.replace(self.IMG_CONTEXT_TOKEN, original_placeholder)

        input_text_updated = input_text
        for num_patches in num_patches_list:
            image_tokens = (
                self.IMG_START_TOKEN
                + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                + self.IMG_END_TOKEN
            )
            input_text_updated = input_text_updated.replace(
                original_placeholder, image_tokens, 1
            )

        # Restore any leftover placeholders (the prompt contained more image
        # tokens than there are images).
        input_text_updated = input_text_updated.replace(
            original_placeholder, self.IMG_CONTEXT_TOKEN
        )

        # Tokenize
        input_ids_tensor = self.tokenizer(input_text_updated, return_tensors="pt")[
            "input_ids"
        ].flatten()
        input_ids = input_ids_tensor.tolist()

        # Get image token offsets
        image_offsets = self.get_mm_items_offset(
            input_ids=input_ids_tensor.to("cuda"),
            mm_token_id=self.mm_tokens.image_token_id,
        )

        items = [
            MultimodalDataItem(
                feature=pixel_values,
                modality=Modality.IMAGE,
                offsets=image_offsets,
            )
        ]

        return {
            "input_ids": input_ids,
            "mm_items": items,
            "im_start_id": self.img_start_token_id,
            "im_end_id": self.img_end_token_id,
            "im_token_id": self.mm_tokens.image_token_id,
        }