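# PaliGemma vision-language model support for text-generation-server.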
from io import BytesIO
from PIL import Image
import torch
import torch.distributed
from opentelemetry import trace
from typing import Iterable, Optional, Tuple
from text_generation_server.models.vlm_causal_lm import (
    VlmCausalLM,
    VlmCausalLMBatch,
    image_text_replacement,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
    PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig

from text_generation_server.pb.generate_pb2 import Request

tracer = trace.get_tracer(__name__)


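# Batch class for PaliGemma: overrides tokenization so that text chunks are
# wrapped in PaliGemma's "<bos>...\n" prompt format and image chunks are
# preprocessed and replaced with placeholder tokens before tokenizing.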
class PaliGemmaBatch(VlmCausalLMBatch):
    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[Request], tokenizer, processor, config
    ):
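        # Flatten each request's interleaved text/image chunks into a single
        # prompt string, collecting the preprocessed image tensors on the side.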
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
        for r in requests:
            full_text = ""
            image_id = 0
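            # Requests interleave text and image chunks; handle each in order.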
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
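                    # PaliGemma prompts expect a <bos> prefix and a trailing
                    # newline separator around the text.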
                    full_text += "<bos>" + chunk.text + "\n"
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
                    # TODO: should do_convert_rgb be enabled by default in the image processor?
                    image = image.convert("RGB")
                    image_input = processor.image_processor(image, return_tensors="pt")
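                    # Stand in for the image with the placeholder tokens that
                    # the vision tower's embeddings will later fill in.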
                    full_text += image_text_replacement(
                        processor, image_input, config, image_id
                    )
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

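        # Tokenize all prompts at once; special tokens are disabled because
        # <bos> was already inserted into the text above.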
        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=False,
        )["input_ids"]
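        # Collate the per-image tensors into single batch tensors so the
        # vision tower can process all images together.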
        if image_inputs:
            image_input = image_inputs[0]
            new_image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
            }
            if "pixel_attention_mask" in image_input:
                new_image_inputs["pixel_attention_mask"] = torch.cat(
                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
                )
            if "image_sizes" in image_input:
                new_image_inputs["image_sizes"] = torch.cat(
                    [img["image_sizes"] for img in image_inputs], dim=0
                )
            image_inputs = new_image_inputs
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs


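# Model wrapper that plugs the flash PaliGemma implementation and its
# processor into the generic vision-language causal LM runner.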
class PaliGemma(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
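        # The processor bundles the tokenizer and the image preprocessor for
        # this checkpoint.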
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

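        # Defer weight loading and runtime setup to VlmCausalLM, pointing it
        # at the flash PaliGemma modeling class.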
        super().__init__(
            config_cls=AutoConfig,
            model_cls=PaliGemmaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self):
        return PaliGemmaBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
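        # Shape information of the text backbone: (number of layers, number of
        # key-value heads, head size), used to size the KV cache.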
        return (
            len(model.text_model.model.layers),
            model.text_model.model.num_key_value_heads,
            model.text_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
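        # Sliding-window limit of the text backbone, if it defines one;
        # None means attention over the full past.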
        return getattr(self.model.text_model, "max_past", None)
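

# A minimal usage sketch (hypothetical; in practice the server's model loader
# instantiates this class):
#
#     model = PaliGemma("google/paligemma-3b-pt-224", dtype=torch.bfloat16)
#     assert model.batch_type is PaliGemmaBatch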