import torch
from PIL import Image
from io import BytesIO

from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict

from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
    FlashCausalLMBatch,
    FlashCausalLM,
    block_tables_to_ragged,
)
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen

tracer = trace.get_tracer(__name__)

IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
IDEFICS2_IMAGE_TOKEN = "<image>"


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple`):
            The size of the input image in the format (height, width).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (height, width).
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size


def image_text_replacement(processor, image_input, config, image_id: int) -> str:
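    """Return the placeholder text that stands in for one image in the prompt.

    The number of repeated image tokens depends on the model type (and, for
    LLaVA-NeXT, on the image resolution), so the tokenized prompt reserves the
    right number of positions for the image features.
    """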
    if config.model_type == "idefics2":
        image_seq_len = 64
        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
        if processor.image_processor.do_image_splitting:
            image_str *= 5
        return image_str
    elif config.model_type == "llava_next":
        height, width = image_input["image_sizes"][image_id]
        num_features = get_number_of_features(height, width, config)
        from loguru import logger

        log_master(
            logger.info,
            f"Found {num_features} features in image of resolution {height}x{width}",
        )
        return "<image>" * num_features

    elif config.model_type == "paligemma":
        return "<image>" * config.text_config.num_image_tokens
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")


def image_text_replacement_fixup(config, text: str) -> str:
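    """Collapse the doubled fake-image tokens produced by adjacent images (Idefics2 only)."""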
    if config.model_type == "idefics2":
        return text.replace(
            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
        )
    return text


def get_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
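    """Compute (unpadded_features, newline_features) for a LLaVA-NeXT style image.

    The high-resolution patch grid is cropped back to the original aspect ratio
    (mirroring the model's unpadding step), and one newline feature is counted
    per remaining row.
    """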
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = original_width / original_height
    current_aspect_ratio: float = current_width / current_height

    if aspect_ratio > current_aspect_ratio:
        new_height = (original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        current_height = current_height - (2 * padding)
    else:
        new_width = (original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        current_width = current_width - (2 * padding)

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


def get_number_of_features(height: int, width: int, config) -> int:
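    """Total number of image tokens a LLaVA-NeXT image of this size expands to:
    unpadded grid features + one newline feature per row + the base (full-image)
    patch grid.
    """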
    # From config
    # Hardcoded for CLIP for now
    # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
    image_grid_pinpoints = config.image_grid_pinpoints
    image_size = config.vision_config.image_size
    patch_size = config.vision_config.patch_size

    assert image_size % patch_size == 0

    npatches = image_size // patch_size

    # Dimensions are intentionally swapped to be bug-compatible with
    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
        [height, width],
        image_grid_pinpoints,
        image_size,
    )
    unpadded_features, newline_features = get_unpadded_features(
        height, width, npatches, num_patch_height, num_patch_width
    )
    # The base patch covers the entire image
    base_features = npatches**2
    return unpadded_features + newline_features + base_features


class VlmCausalLMBatch(FlashCausalLMBatch):
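    """Flash-attention batch extended with the vision inputs needed during prefill."""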
    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
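        # Vision tensors are consumed during prefill, so merged batches no longer carry them.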
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
    ):
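        """Tokenize a batch of chunked (text/image) requests.

        Images are pre-processed first so the processor can size all splits
        consistently; each image chunk is then replaced in the text by the
        matching number of image placeholder tokens before tokenization.
        Returns the tokenized input ids and the processed image inputs.
        """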
        # Process images first. We need all of them so that the processor
        # can make the image splits the same size. And we need the final
        # sizes to insert correct number of image tokens.
        images = []
        for r in requests:
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    pass
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
                    if config.model_type == "llava_next":
                        images.append(image)
                    else:
                        images.append([image])
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

        if images:
            image_inputs = processor.image_processor(images, return_tensors="pt")
        else:
            image_inputs = None

        batch_inputs = []
        max_truncation = 0
        image_id = 0
        for r in requests:
            full_text = ""
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    full_text += chunk.text
                elif chunk_type == "image":
                    full_text += image_text_replacement(
                        processor, image_inputs, config, image_id
                    )
                    image_id += 1

            full_text = image_text_replacement_fixup(config, full_text)

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=config.model_type != "paligemma",
        )["input_ids"]

        return batch_tokenized_inputs, image_inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmCausalLMBatch":
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
            if "pixel_attention_mask" in image_inputs:
                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
                    device=device
                )
            else:
                batch.pixel_attention_mask = None
            if "image_sizes" in image_inputs:
                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
            else:
                batch.image_sizes = None
        else:
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
        return batch


class VlmCausalLM(FlashCausalLM):
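    """FlashCausalLM wrapper for vision-language models (Idefics2, LLaVA-NeXT, PaliGemma)."""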
    def __init__(
        self,
        model_id: str,
        *,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        batch_class=VlmCausalLMBatch,
        revision,
        trust_remote_code: bool,
        **kwargs,
    ):
        if PREFIX_CACHING:
            raise NotImplementedError("VLMs do not work with prefix caching yet")
        if processor_kwargs is None:
            processor_kwargs = {}
        self.processor = processor_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **processor_kwargs,
        )
        self.batch_class = batch_class
        super().__init__(
            model_id=model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return self.batch_class

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)

    def forward(
        self,
        batch: VlmCausalLMBatch,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
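            # Expand every sequence with its speculative tokens so the whole
            # (token + candidates) block is scored in a single forward pass.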
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            prefix_lens_tensor = (
                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
            ).reshape(-1)

            # Copy the block tables for all members of the speculative batch
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            prefix_lens_tensor = batch.prefix_lens_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        # Try to find an associated cuda graph
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
        if sorted_padded_bs:
            # Get associated cuda graph
            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
        else:
            cuda_graph = None
        if cu_seqlen_prefill is not None or cuda_graph is None:
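            # Prefill, or no captured CUDA graph for this batch size: run the model eagerly.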
            input_lengths = input_lengths + prefix_lens_tensor
            if PREFIX_CACHING:
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=batch.input_lengths,
                    prefix_lens=batch.prefix_lens,
                )
            with self._forward_context(
                block_tables=block_tables,
                cu_seqlen_prefill=cu_seqlen_prefill,
                input_lengths=batch.input_lengths,
                input_lengths_tensor=input_lengths,
                prefix_lens=batch.prefix_lens,
                prefix_lens_tensor=prefix_lens_tensor,
            ):
                max_k = (input_lengths + prefix_lens_tensor).max().item()
                seqlen = Seqlen(
                    input_lengths=input_lengths,
                    prefix_lengths=prefix_lens_tensor,
                    cu_seqlen_q=cu_seqlen_prefill,
                    max_q=max_s,
                    max_k=max_k,
                )
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=cu_seqlen_prefill,
                    kv_cache=kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=batch.prefill_cache_indices,
                    lm_head_indices=lm_head_indices,
                    pixel_values=batch.pixel_values,
                    pixel_attention_mask=batch.pixel_attention_mask,
                    image_sizes=batch.image_sizes,
                )
                if batch.prefill_cache_indices is not None:
                    batch.prefill_cache_indices = None
                if batch.pixel_values is not None:
                    batch.pixel_values = None
                if batch.pixel_attention_mask is not None:
                    batch.pixel_attention_mask = None
                if batch.image_sizes is not None:
                    batch.image_sizes = None
                return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=batch.input_lengths,
                prefix_lens=batch.prefix_lens,
            )
            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
        else:
            cuda_graph["block_tables"][
                : block_tables.shape[0], : block_tables.shape[1]
            ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
            input_lengths + prefix_lens_tensor
        )

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits