# pali_gemma.py
from io import BytesIO
from PIL import Image
import torch
import torch.distributed
from opentelemetry import trace
from typing import Iterable
from text_generation_server.models.vlm_causal_lm import (
    VlmCausalLMBatch,
    image_text_replacement,
)

from text_generation_server.pb.generate_pb2 import Request

tracer = trace.get_tracer(__name__)


class PaliGemmaBatch(VlmCausalLMBatch):
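    """VlmCausalLMBatch that builds PaliGemma prompts from text and image chunks."""
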
    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[Request], tokenizer, processor, config
    ):
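        """Tokenize a batch of requests and collate their image inputs.

        Text chunks are accumulated into one prompt string per request, with
        each image chunk replaced by its placeholder text; the per-image pixel
        tensors are concatenated into a single batch at the end.
        """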
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
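        # Build one prompt string per request, expanding image chunks into
        # their placeholder tokens as we go.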
        for r in requests:
            full_text = ""
            image_id = 0
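            # Note: image_id is not incremented below; every image chunk in a
            # request is passed to image_text_replacement with id 0.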
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
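                    # Each text chunk is prefixed with <bos> and terminated with a newline.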
                    full_text += "<bos>" + chunk.text + "\n"
                elif chunk_type == "image":
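                    # Decode the raw image bytes carried by the request chunk.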
                    image = Image.open(BytesIO(chunk.image.data))
                    # TODO: should do_convert_rgb be enabled by default in the processor?
                    image = image.convert("RGB")
                    image_input = processor.image_processor(image, return_tensors="pt")
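                    # Replace the image chunk with the placeholder text expected
                    # by the model for this image.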
                    full_text += image_text_replacement(
                        processor, image_input, config, image_id
                    )
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

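        # Tokenize all prompts in one call; add_special_tokens=False because
        # <bos> was already inserted manually above.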
        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=False,
        )["input_ids"]
        if image_inputs:
            image_input = image_inputs[0]
            new_image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
            }
            if "pixel_attention_mask" in image_input:
                new_image_inputs["pixel_attention_mask"] = torch.cat(
                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
                )
            if "image_sizes" in image_input:
                new_image_inputs["image_sizes"] = torch.cat(
                    [img["image_sizes"] for img in image_inputs], dim=0
                )
            image_inputs = new_image_inputs
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs