"vscode:/vscode.git/clone" did not exist on "8cc27fdc4631ebda34d4247f2c8dd3cd32152f13"
vlm_causal_lm.py 14.6 KB
Newer Older
1
2
3
4
5
import torch
from PIL import Image
from io import BytesIO

from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict

from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
    FlashCausalLMBatch,
    FlashCausalLM,
)
from text_generation_server.utils.log import log_master
from loguru import logger
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen

tracer = trace.get_tracer(__name__)

IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
IDEFICS2_IMAGE_TOKEN = "<image>"


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after preprocessing images of any resolution.

    Args:
        image_size (`tuple`):
            The size of the input image in the format (height, width).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (height, width).
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size
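
# Example with hypothetical CLIP-style values: a 600x800 image with pinpoints
# [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]] resolves to
# the (672, 672) resolution, so with patch_size=336:
#
#     get_anyres_image_grid_shape((600, 800), pinpoints, 336)  # -> (2, 2)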


def image_text_replacement(processor, image_input, config, image_id: int) -> str:
    if config.model_type == "idefics2":
        image_seq_len = 64
        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
        if processor.image_processor.do_image_splitting:
            image_str *= 5
        return image_str
    elif config.model_type == "llava_next":
        height, width = image_input["image_sizes"][image_id]
        num_features = get_number_of_features(height, width, config)

        log_master(
            logger.info,
            f"Found {num_features} features in image of resolution {height}x{width}",
        )
        return "<image>" * num_features

    elif config.model_type == "paligemma":
        return "<image>" * config.text_config.num_image_tokens
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
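
# Illustration (hypothetical): on idefics2 with do_image_splitting enabled, a
# single image expands to five copies of
# "<fake_token_around_image>" + "<image>" * 64 + "<fake_token_around_image>",
# i.e. 5 * 64 = 320 "<image>" tokens. The doubled boundary tokens between
# copies are collapsed by image_text_replacement_fixup below.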


def image_text_replacement_fixup(config, text: str) -> str:
    if config.model_type == "idefics2":
        return text.replace(
            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
        )
    return text
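
# Illustration (hypothetical): two adjacent idefics2 image strings meet as
# "...<image><fake_token_around_image><fake_token_around_image><image>...",
# and the fixup collapses the doubled boundary token into a single one.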


def get_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = original_width / original_height
    current_aspect_ratio: float = current_width / current_height

    if aspect_ratio > current_aspect_ratio:
        new_height = (original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        current_height = current_height - (2 * padding)
    else:
        new_width = (original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        current_width = current_width - (2 * padding)

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)
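
# Worked example (hypothetical values): original_height=600, original_width=800,
# npatches=24, and a 2x2 patch grid give current_height = current_width = 48.
# The image (4:3) is wider than the grid (1:1), so the height shrinks to
# (600 * 48) // 800 = 36 after stripping padding, yielding
# (unpadded_features, newline_features) = (36 * 48, 36) = (1728, 36).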


def get_number_of_features(height: int, width: int, config) -> int:
    # From config
    # Hardcoded for CLIP for now
    # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
    image_grid_pinpoints = config.image_grid_pinpoints
    image_size = config.vision_config.image_size
    patch_size = config.vision_config.patch_size

    assert image_size % patch_size == 0

    npatches = image_size // patch_size

    # Dimensions are intentionally swapped to be bug-compatible with
    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
        [height, width],
        image_grid_pinpoints,
        image_size,
    )
    unpadded_features, newline_features = get_unpadded_features(
        height, width, npatches, num_patch_height, num_patch_width
    )
    # The base patch covers the entire image
    base_features = npatches**2
    return unpadded_features + newline_features + base_features
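
# Worked example (hypothetical CLIP-style config): image_size=336 and
# patch_size=14 give npatches=24. A 600x800 image on a 2x2 grid then yields
# 1728 unpadded features + 36 newline features + 24**2 = 576 base features,
# i.e. 2340 "<image>" placeholder tokens in total.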


class VlmCausalLMBatch(FlashCausalLMBatch):
    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
    ):
        # Process images first. We need all of them so that the processor
        # can make the image splits the same size. And we need the final
        # sizes to insert correct number of image tokens.
        images = []
        for r in requests:
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    pass
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
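                    # llava_next's processor expects a flat list of images,
                    # while the other processors used here take one list of
                    # images per sample.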
                    if config.model_type == "llava_next":
                        images.append(image)
                    else:
                        images.append([image])
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

        if images:
            image_inputs = processor.image_processor(images, return_tensors="pt")
        else:
            image_inputs = None

        batch_inputs = []
        max_truncation = 0
        image_id = 0
        for r in requests:
            full_text = ""
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    full_text += chunk.text
                elif chunk_type == "image":
                    full_text += image_text_replacement(
                        processor, image_inputs, config, image_id
                    )
                    image_id += 1

            full_text = image_text_replacement_fixup(config, full_text)

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=config.model_type != "paligemma",
        )["input_ids"]

        return batch_tokenized_inputs, image_inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmCausalLMBatch":
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
            if "pixel_attention_mask" in image_inputs:
                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
                    device=device
                )
            else:
                batch.pixel_attention_mask = None
            if "image_sizes" in image_inputs:
                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
            else:
                batch.image_sizes = None
        else:
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
        return batch
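
# A minimal usage sketch (hypothetical wiring; names are illustrative):
#
#     processor = AutoProcessor.from_pretrained(model_id)
#     batch = VlmCausalLMBatch.from_pb_processor(
#         pb_batch, tokenizer, processor, model.config, torch.float16, device
#     )
#
# After this, batch carries tokenized input_ids plus any pixel_values /
# pixel_attention_mask / image_sizes tensors already moved to the device.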


class VlmCausalLM(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        *,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        batch_class=VlmCausalLMBatch,
        revision,
        trust_remote_code: bool,
        **kwargs,
    ):
        if processor_kwargs is None:
            processor_kwargs = {}
        self.processor = processor_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **processor_kwargs,
        )
        self.batch_class = batch_class
        super().__init__(
            model_id=model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return self.batch_class

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)

    def forward(
        self,
        batch: VlmCausalLMBatch,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

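            # Expand each sequence from one position to (1 + speculative_length)
            # positions so the model scores all speculative tokens in a single
            # forward pass. Hypothetical shapes: with B=2 and
            # speculative_length=2, input_ids [a, b] and speculative_ids
            # [[a1, a2], [b1, b2]] become [a, a1, a2, b, b1, b2]; position ids,
            # slots and input lengths are offset by arange(3) per sequence.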
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members of the expanded batch
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
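            # (For example, hypothetically, max_past=4096 and max_s=5000 clamp
            # the decode pass to max_s=4096.)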
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        # Try to find an associated cuda graph
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
        if sorted_padded_bs:
            # Get associated cuda graph
            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
        else:
            cuda_graph = None
        if cu_seqlen_prefill is not None or cuda_graph is None:
            input_lengths = Seqlen(input_lengths=input_lengths)
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                pixel_values=batch.pixel_values,
                pixel_attention_mask=batch.pixel_attention_mask,
                image_sizes=batch.image_sizes,
            )
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            if batch.pixel_values is not None:
                batch.pixel_values = None
            if batch.pixel_attention_mask is not None:
                batch.pixel_attention_mask = None
            if batch.image_sizes is not None:
                batch.image_sizes = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
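        # Slots are reset to -1 and input lengths to 0 first, so the padded
        # tail of the static batch stays inert across replays.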
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits