pali_gemma.py
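"""PaliGemma support for text-generation-server.

PaliGemmaBatch flattens each request's text/image chunks into a single
prompt (inserting image placeholder tokens via image_text_replacement) and
collates the processed image tensors; PaliGemma wires the flash PaliGemma
model into the generic VLM causal-LM runtime.
"""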
from io import BytesIO
from PIL import Image
import torch
import torch.distributed
from opentelemetry import trace
from typing import Iterable, Optional, Tuple
from text_generation_server.models.vlm_causal_lm import (
    VlmCausalLM,
    VlmCausalLMBatch,
    image_text_replacement,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
    PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig

from text_generation_server.pb.generate_pb2 import Request

tracer = trace.get_tracer(__name__)


class PaliGemmaBatch(VlmCausalLMBatch):
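    """VLM batch that flattens PaliGemma text/image chunks into prompts."""
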
    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[Request], tokenizer, processor, config
    ):
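        """Tokenize a batch of multimodal requests.

        Returns (token ids per request, collated image inputs); the image
        inputs are None when no request contained an image chunk.
        """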
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
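        # Flatten each request's chunks into one prompt string, collecting
        # processed image tensors as they appear.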
        for r in requests:
            full_text = ""
            image_id = 0  # note: never incremented; every image in a request gets id 0
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    full_text += "<bos>" + chunk.text + "\n"
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
                    # TODO: do_convert_RGB should be on by default?
                    image = image.convert("RGB")
                    image_input = processor.image_processor(image, return_tensors="pt")
                    full_text += image_text_replacement(
                        processor, image_input, config, image_id
                    )
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=False,
        )["input_ids"]
        if image_inputs:
            image_input = image_inputs[0]
            new_image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
            }
            if "pixel_attention_mask" in image_input:
                new_image_inputs["pixel_attention_mask"] = torch.cat(
                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
                )
            if "image_sizes" in image_input:
                new_image_inputs["image_sizes"] = torch.cat(
                    [img["image_sizes"] for img in image_inputs], dim=0
                )
            image_inputs = new_image_inputs
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs


class PaliGemma(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
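        # Load the processor up front; model and config loading is deferred
        # to VlmCausalLM.__init__ below with the flash PaliGemma model class.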
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        super().__init__(
            config_cls=AutoConfig,
            model_cls=PaliGemmaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self):
        return PaliGemmaBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
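        """Return (num_layers, num_key_value_heads, head_size) for the text model."""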
        return (
            len(model.text_model.model.layers),
            model.text_model.model.num_key_value_heads,
            model.text_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
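        """Maximum past sequence length the text model defines, if any."""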
        return getattr(self.model.text_model, "max_past", None)
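

# Usage sketch (hypothetical model id; text-generation-server normally
# constructs this class through its model factory rather than directly):
#
#     model = PaliGemma(
#         "google/paligemma-3b-pt-224",  # assumed checkpoint name
#         dtype=torch.bfloat16,
#     )
#     assert model.batch_type is PaliGemmaBatch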